In [None]:
from typing import Callable, Optional

import torch
from lightning import pytorch as pl

from train.architectures import Architecture
from train.callbacks import ModelCheckpoint, PsdPlotter
from train.metrics import OnlinePsdRatio, PsdRatio

Tensor = torch.Tensor


class DeepClean(pl.LightningModule):
    """
    Lightning module that trains a DeepClean noise-regression network.

    Training steps minimize ``loss`` directly on each batch. Validation and
    test steps only accumulate strain and noise predictions in ``metric``;
    the epoch-level ``val_loss``/``test_loss`` is computed and logged by the
    ``PsdPlotter`` callback.

    Args:
        arch: network mapping a batch of inputs ``X`` to noise predictions.
        loss: PSD-ratio loss called as ``loss(y_pred, y_true)``; the
            training step minimizes it, so lower is better.
        metric: online metric that accumulates evaluation outputs; its
            ``loss_fn`` is set to ``loss``.
        patience: if given, number of validation epochs without
            ``val_loss`` improvement before training stops early.
        save_top_k_models: number of best checkpoints (by ``val_loss``)
            to keep.
    """

    def __init__(
        self,
        arch: Architecture,
        loss: PsdRatio,
        metric: OnlinePsdRatio,
        patience: Optional[int] = None,
        save_top_k_models: int = 10,
    ) -> None:
        super().__init__()
        # arch/loss/metric are objects rather than plain hyperparameters, so
        # they are excluded here and must be passed again when calling
        # ``load_from_checkpoint``.
        self.save_hyperparameters(ignore=["arch", "loss", "metric"])

        self.model = arch
        self.loss = loss
        self.metric = metric
        # the metric scores its cleaned timeseries with the training loss
        self.metric.loss_fn = self.loss

    def forward(self, X: Tensor) -> Tensor:
        return self.model(X)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> Tensor:
        """Compute, log and return the mean loss for one ``(X, y_true)`` batch."""
        X, y_true = batch
        y_pred = self(X)
        loss = self.loss(y_pred, y_true).mean()
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def _shared_eval_step(self, X, y_true) -> None:
        """
        Accumulate strain and/or noise predictions in the online metric.

        Note that the actual computation of the loss function
        happens via the PsdPlotter Callback.

        Args:
            X: model inputs, or None if this batch carries no inputs.
            y_true: target strain, or None; only ``y_true[:, 0]`` is kept.
        """
        if y_true is not None:
            self.metric.update(y_true[:, 0], "strain")
        if X is not None:
            y_pred = self(X)
            self.metric.update(y_pred, "predictions")

    def validation_step(self, batch, _) -> None:
        return self._shared_eval_step(*batch)

    def test_step(self, batch, _) -> None:
        return self._shared_eval_step(*batch)

    def configure_callbacks(self) -> list[pl.Callback]:
        """Return the PSD plotter, checkpointer and optional early stopper."""
        # first callback actually computes all of our
        # validation metrics and any associated plots
        callbacks = [PsdPlotter()]

        # then tack on a checkpointer that uses these
        # metrics for checkpointing the model
        checkpoint = ModelCheckpoint(
            monitor="val_loss",
            save_top_k=self.hparams.save_top_k_models,
            save_last=True,
            auto_insert_metric_name=False,
            # val_loss is a loss (lower is better), consistent with the
            # early-stopping mode below; "max" would rank the worst
            # checkpoints as the best ones.
            mode="min",
        )
        callbacks.append(checkpoint)

        # if we specified an early-stopping patience
        # interval, add early stopping
        if self.hparams.patience is not None:
            early_stop = pl.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=self.hparams.patience,
                mode="min",
                min_delta=0.00,
            )
            callbacks.append(early_stop)
        return callbacks

In [None]:
# Keyword arguments for each training component. These are empty
# placeholders: fill them in to match your actual configuration.
# NOTE(review): if these classes match the PsdRatio / OnlinePsdRatio defined
# later in this notebook, their constructors have required arguments and
# the empty dicts below will raise TypeError — confirm.
architecture_args = dict()
psd_ratio_args = dict()
online_psd_ratio_args = dict()

# Build the architecture, then the loss, then the metric.
arch = Architecture(**architecture_args)
loss = PsdRatio(**psd_ratio_args)
metric = OnlinePsdRatio(**online_psd_ratio_args)

In [None]:
# Checkpoint to restore: the `last.ckpt` written by run version_34.
checkpoint_path = '/home/shuwei.yeh/deepclean/results/K1_train_test/lightning_logs/version_34/checkpoints/last.ckpt'

# arch / loss / metric are excluded from the saved hyperparameters
# (`save_hyperparameters(ignore=...)`), so they have to be supplied again.
model = DeepClean.load_from_checkpoint(
    checkpoint_path, metric=metric, loss=loss, arch=arch
)

In [None]:
import os
import torch
from train.architectures import Architecture  # Ensure this is correctly imported based on your project structure
from train.metrics import OnlinePsdRatio, PsdRatio  # Same here
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def load_model(checkpoint_path, arch, loss, metric):
    """
    Restore a DeepClean module from a checkpoint file.

    The architecture, loss and metric are not saved as hyperparameters,
    so they must be handed back in when restoring.

    Args:
    - checkpoint_path (str): Path to the checkpoint file.
    - arch (Architecture): The architecture of the model.
    - loss (PsdRatio): The loss function used in the model.
    - metric (OnlinePsdRatio): The metric used in the model.

    Returns:
    - DeepClean: The loaded model.
    """
    components = {"arch": arch, "loss": loss, "metric": metric}
    return DeepClean.load_from_checkpoint(
        checkpoint_path=checkpoint_path, **components
    )

def validate_model(model, dataloader, **trainer_kwargs):
    """
    Run a single validation pass of ``model`` over ``dataloader``.

    Args:
    - model (DeepClean): The model to validate.
    - dataloader (DataLoader): DataLoader for validation data.
    - **trainer_kwargs: Optional keyword arguments forwarded to the Trainer
      constructor (e.g. ``accelerator``, ``devices``, ``logger``).

    Returns:
    - list[dict]: The metrics returned by ``Trainer.validate``, one dict per
      dataloader. When the module's callbacks log ``"val_loss"``, that key
      holds the validation loss.
    """
    # DeepClean subclasses ``lightning.pytorch.LightningModule``. The
    # standalone ``pytorch_lightning`` package defines distinct classes, and
    # its Trainer rejects such a module with a TypeError, so the Trainer must
    # come from ``lightning.pytorch`` as well.
    from lightning.pytorch import Trainer as LightningTrainer

    trainer = LightningTrainer(**trainer_kwargs)
    return trainer.validate(model, dataloaders=dataloader)

# Checkpoint to restore.
checkpoint_path = '/home/shuwei.yeh/deepclean/results/K1_train_test/lightning_logs/version_34/checkpoints/last.ckpt'

# Placeholder components: replace these default constructions with the real
# initializations. NOTE(review): PsdRatio / OnlinePsdRatio take required
# constructor arguments if they match the classes defined later in this
# notebook, in which case these bare calls raise TypeError — confirm.
arch = Architecture()
loss = PsdRatio()
metric = OnlinePsdRatio()

model = load_model(checkpoint_path, arch, loss, metric)

# To validate, build a DataLoader over the validation data and then run:
#     validation_loss = validate_model(model, validation_dataloader)
#     print(f"Validation Loss: {validation_loss}")


In [None]:
# Loss: ratio of cleaned-strain PSD to raw-strain PSD in the 55-65 Hz band.
sample_rate = 4096  # Example sample rate in Hz
fftlength = 2  # Example FFT length in seconds
freq_low = [55]  # Example lower frequency bounds (Hz)
freq_high = [65]  # Example upper frequency bounds (Hz)
loss = PsdRatio(sample_rate=sample_rate, fftlength=fftlength, freq_low=freq_low, freq_high=freq_high)

# TODO: initialize with the real architecture arguments
arch = Architecture()


class IdentityScaler(torch.nn.Module):
    """
    Placeholder ``y_scaler`` that returns its input unchanged.

    ``OnlinePsdRatio`` calls ``y_scaler(noise, reverse=True)``.
    ``torch.nn.Identity`` cannot be used there because its ``forward``
    accepts a single input only, so the call would raise TypeError.
    """

    def forward(self, x, reverse=False):
        return x


# Metric: online PSD ratio.
# Rate (Hz) at which prediction windows are produced; the metric assumes
# consecutive windows are ``sample_rate / inference_sampling_rate`` samples
# apart. This is NOT the data sample rate, and should match the data
# module's inference_sampling_rate.
inference_sampling_rate = 64
edge_pad = 0.25  # Example edge padding in seconds
filter_pad = 0.25  # Example filter padding in seconds
bandpass_callable = lambda x: x  # Dummy bandpass callable, replace with actual
y_scaler_module = IdentityScaler()  # Dummy scaler, replace with actual
metric = OnlinePsdRatio(
    inference_sampling_rate=inference_sampling_rate,
    edge_pad=edge_pad,
    filter_pad=filter_pad,
    sample_rate=sample_rate,
    bandpass=bandpass_callable,
    y_scaler=y_scaler_module
)


In [None]:
import os

import h5py
import torch
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from train.plotting import plot_psds
from utils.plotting.utils import save

from typing import Optional

import torch
from lightning import pytorch as pl

from train.architectures import Architecture
from train.callbacks import ModelCheckpoint, PsdPlotter
from train.metrics import OnlinePsdRatio, PsdRatio

In [None]:
class PsdPlotter(Callback):
    """
    Computes and logs the epoch-level validation/test loss from the module's
    online metric, and saves PSD plots of the cleaned data.

    Validation plots go to ``<log_dir>/plots``. Test plots and an
    ``outputs.hdf5`` file with the ``noise``/``strain`` arrays go to
    ``<log_dir>/test``. When the logger is a ``WandbLogger``, plots are also
    logged to W&B as an HTML table.
    """

    @staticmethod
    def _log_dir(trainer):
        # Prefer the logger's versioned log dir, then its save dir. Without a
        # logger (``logger=False``), fall back to the trainer's root dir
        # instead of failing with AttributeError on ``None``.
        logger = trainer.logger
        if logger is None:
            return trainer.default_root_dir
        return logger.log_dir or logger.save_dir or trainer.default_root_dir

    def _make_plot_dir(self, trainer):
        # TODO: support s3 here
        self.plot_dir = os.path.join(self._log_dir(trainer), "plots")
        os.makedirs(self.plot_dir, exist_ok=True)

    def on_fit_start(self, trainer, pl_module):
        self._make_plot_dir(trainer)

    def on_validation_start(self, trainer, pl_module):
        # ``Trainer.validate`` runs without ``on_fit_start``, so make sure the
        # plot directory exists before ``on_validation_epoch_end`` uses it.
        if getattr(self, "plot_dir", None) is None:
            self._make_plot_dir(trainer)

    def on_test_start(self, trainer, pl_module):
        self.test_dir = os.path.join(self._log_dir(trainer), "test")
        os.makedirs(self.test_dir, exist_ok=True)

    def log_plots(self, layout, fname, trainer):
        """Save ``layout`` to ``fname`` and, under W&B, log it as a table."""
        # always save the plots locally
        save(layout, fname, title="DeepClean PSDs")

        # if using W&B, log the plots as artifacts
        if isinstance(trainer.logger, WandbLogger):
            import wandb

            key = os.path.basename(fname).split("-")[0]
            html = wandb.Html(fname)
            trainer.logger.log_table(
                "samples", columns=[f"{key}-psds"], data=[[html]]
            )

    def _shared_eval(self, pl_module):
        """Return ``(noise, strain, plot)`` and reset the module's metric."""
        # use our metric to produce the online-cleaned
        # noise prediction and strain timeseries, calling
        # compute to handle any distributed-training related
        # aggregation, then compute our loss functions on
        # these timeseries and log the output
        noise, strain = pl_module.metric.compute(reduce=False)
        pl_module.metric.reset()

        spectral_density = pl_module.loss.spectral_density
        fftlength = spectral_density.nperseg / pl_module.loss.sample_rate
        p = plot_psds(
            noise,
            strain,
            pl_module.loss.mask,
            spectral_density,
            fftlength,
            pl_module.loss.asd,
        )
        return noise, strain, p

    def on_validation_epoch_end(self, trainer, pl_module):
        # compute the loss on the online-cleaned timeseries and log it
        noise, strain, p = self._shared_eval(pl_module)
        loss = pl_module.loss(noise, strain)
        pl_module.log(
            "val_loss",
            loss,
            on_epoch=True,
            sync_dist=True,
            logger=True,
            prog_bar=True,
        )

        # use these timeseries to plot their ASDs
        # as well as their ratios
        step = str(trainer.global_step).zfill(5)
        fname = f"val-psds_step-{step}.html"
        fname = os.path.join(self.plot_dir, fname)
        self.log_plots(p, fname, trainer)

    def on_test_epoch_end(self, trainer, pl_module):
        noise, strain, p = self._shared_eval(pl_module)
        loss = pl_module.loss(noise, strain)
        pl_module.log(
            "test_loss", loss, on_epoch=True, sync_dist=True, logger=True
        )

        fname = os.path.join(self.test_dir, "test-psds.html")
        self.log_plots(p, fname, trainer)

        fname = os.path.join(self.test_dir, "outputs.hdf5")
        with h5py.File(fname, "w") as f:
            f["noise"] = noise.cpu().numpy()
            f["strain"] = strain.cpu().numpy()


class ModelCheckpoint(ModelCheckpoint):
    """
    ``ModelCheckpoint`` that, when training ends, exports the best model's
    network as a TorchScript trace named ``model.pt`` in the logger's
    directory (a local path or an ``s3://`` URL).
    """

    def on_train_end(self, trainer, pl_module):
        if self.best_model_path:
            # ``arch`` is the live ``pl_module.model``, so the best weights
            # are loaded into that same network object.
            module = pl_module.__class__.load_from_checkpoint(
                self.best_model_path,
                arch=pl_module.model,
                metric=pl_module.metric,
                loss=pl_module.loss,
            )
        else:
            # No best checkpoint was recorded (e.g. the monitored metric was
            # never logged); export the current weights rather than failing
            # on an empty checkpoint path.
            module = pl_module

        # TODO: we should probably establish an explicit
        # validation_kernel_length that matches what
        # we use at test time. If there were any issues
        # with it, we would have caught it by now during
        # validation, but worth making it explicit that
        # these are different values.
        datamodule = trainer.datamodule
        kernel_size = int(
            datamodule.hparams.kernel_length * datamodule.sample_rate
        )

        # example input of shape (1, num_witnesses, kernel_size) for tracing
        num_witnesses = len(datamodule.witness_channels)
        sample_input = torch.randn(1, num_witnesses, kernel_size)
        model = module.model.to("cpu")
        # Trace in eval mode: tracing records the ops actually executed, so
        # train-mode behaviour (e.g. dropout) would otherwise be baked into
        # the exported graph.
        model.eval()
        trace = torch.jit.trace(model, sample_input)

        save_dir = trainer.logger.log_dir or trainer.logger.save_dir
        if save_dir.startswith("s3://"):
            import s3fs

            s3 = s3fs.S3FileSystem()
            with s3.open(f"{save_dir}/model.pt", "wb") as f:
                torch.jit.save(trace, f)
        else:
            with open(os.path.join(save_dir, "model.pt"), "wb") as f:
                torch.jit.save(trace, f)

In [None]:
Tensor = torch.Tensor


class DeepClean(pl.LightningModule):
    """
    Lightning module that trains a DeepClean noise-regression network.

    Training steps minimize ``loss`` directly on each batch. Validation and
    test steps only accumulate strain and noise predictions in ``metric``;
    the epoch-level ``val_loss``/``test_loss`` is computed and logged by the
    ``PsdPlotter`` callback.

    Args:
        arch: network mapping a batch of inputs ``X`` to noise predictions.
        loss: PSD-ratio loss called as ``loss(y_pred, y_true)``; the
            training step minimizes it, so lower is better.
        metric: online metric that accumulates evaluation outputs; its
            ``loss_fn`` is set to ``loss``.
        patience: if given, number of validation epochs without
            ``val_loss`` improvement before training stops early.
        save_top_k_models: number of best checkpoints (by ``val_loss``)
            to keep.
    """

    def __init__(
        self,
        arch: Architecture,
        loss: PsdRatio,
        metric: OnlinePsdRatio,
        patience: Optional[int] = None,
        save_top_k_models: int = 10,
    ) -> None:
        super().__init__()
        # arch/loss/metric are objects rather than plain hyperparameters, so
        # they are excluded here and must be passed again when calling
        # ``load_from_checkpoint``.
        self.save_hyperparameters(ignore=["arch", "loss", "metric"])

        self.model = arch
        self.loss = loss
        self.metric = metric
        # the metric scores its cleaned timeseries with the training loss
        self.metric.loss_fn = self.loss

    def forward(self, X: Tensor) -> Tensor:
        return self.model(X)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> Tensor:
        """Compute, log and return the mean loss for one ``(X, y_true)`` batch."""
        X, y_true = batch
        y_pred = self(X)
        loss = self.loss(y_pred, y_true).mean()
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def _shared_eval_step(self, X, y_true) -> None:
        """
        Accumulate strain and/or noise predictions in the online metric.

        Note that the actual computation of the loss function
        happens via the PsdPlotter Callback.

        Args:
            X: model inputs, or None if this batch carries no inputs.
            y_true: target strain, or None; only ``y_true[:, 0]`` is kept.
        """
        if y_true is not None:
            self.metric.update(y_true[:, 0], "strain")
        if X is not None:
            y_pred = self(X)
            self.metric.update(y_pred, "predictions")

    def validation_step(self, batch, _) -> None:
        return self._shared_eval_step(*batch)

    def test_step(self, batch, _) -> None:
        return self._shared_eval_step(*batch)

    def configure_callbacks(self) -> list[pl.Callback]:
        """Return the PSD plotter, checkpointer and optional early stopper."""
        # first callback actually computes all of our
        # validation metrics and any associated plots
        callbacks = [PsdPlotter()]

        # then tack on a checkpointer that uses these
        # metrics for checkpointing the model
        checkpoint = ModelCheckpoint(
            monitor="val_loss",
            save_top_k=self.hparams.save_top_k_models,
            save_last=True,
            auto_insert_metric_name=False,
            # val_loss is a loss (lower is better), consistent with the
            # early-stopping mode below; "max" would rank the worst
            # checkpoints as the best ones.
            mode="min",
        )
        callbacks.append(checkpoint)

        # if we specified an early-stopping patience
        # interval, add early stopping
        if self.hparams.patience is not None:
            early_stop = pl.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=self.hparams.patience,
                mode="min",
                min_delta=0.00,
            )
            callbacks.append(early_stop)
        return callbacks

In [None]:
# Checkpoint to restore: the `9-1270.ckpt` file from run version_34.
checkpoint_path = '/home/shuwei.yeh/deepclean/results/K1_train_test/lightning_logs/version_34/checkpoints/9-1270.ckpt'

# arch / loss / metric are excluded from the saved hyperparameters
# (`save_hyperparameters(ignore=...)`), so they have to be supplied again.
model = DeepClean.load_from_checkpoint(
    checkpoint_path, metric=metric, loss=loss, arch=arch
)

In [None]:
import os
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger  # If you're using Weights & Biases for logging
from pytorch_lightning.callbacks import EarlyStopping

from train.model import DeepClean
from train.data import DeepCleanDataset
from train.callbacks import PsdPlotter, ModelCheckpoint

# Set a random seed for reproducibility
seed_everything(101588)


def train_model():
    """
    Train DeepClean on the K1 data file below, then run the test loop.

    Builds the data module, a DeepClean model (architecture, PSD-ratio loss,
    online PSD-ratio metric) and a Lightning ``Trainer`` with W&B logging,
    then calls ``fit`` followed by ``test`` on the same data module.

    NOTE(review): the DeepClean copy in this notebook defines no
    ``configure_optimizers``. If ``train.model.DeepClean`` doesn't either
    (e.g. it relies on LightningCLI to inject the optimizer), ``fit`` will
    fail here — confirm.
    """
    # DeepClean is a ``lightning.pytorch`` LightningModule, so the Trainer,
    # logger and stock callbacks must come from ``lightning.pytorch`` too.
    # The standalone ``pytorch_lightning`` package defines distinct classes,
    # and its Trainer rejects the model with a TypeError.
    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping
    from lightning.pytorch.loggers import WandbLogger

    from train.architectures import Architecture
    from train.metrics import OnlinePsdRatio, PsdRatio

    class _IdentityScaler(torch.nn.Module):
        """
        Placeholder target scaler that returns its input unchanged.

        The metric calls ``y_scaler(x, reverse=True)``; ``torch.nn.Identity``
        only accepts a single input and would raise TypeError there.
        """

        def forward(self, x, reverse=False):
            return x

    sample_rate = 4096  # Hz
    # The metric stitches predictions assuming consecutive windows are
    # ``sample_rate / inference_sampling_rate`` samples apart, so it must use
    # the same inference rate as the data module.
    inference_sampling_rate = 64  # Hz

    # Initialize the data module
    data_module = DeepCleanDataset(
        fname='/home/shuwei.yeh/deepclean/data/K-K1_lldata-1369291863-12288.hdf5',
        channels=['K1:CAL-CS_PROC_DARM_STRAIN_DBL_DQ', 'K1:PEM-MIC_OMC_BOOTH_OMC_Z_OUT_DQ'],
        kernel_length=0.25,
        freq_low=[55],  # Hz
        freq_high=[65],  # Hz
        batch_size=32,
        train_duration=4096,  # presumably seconds (~68 minutes)
        test_duration=8192,  # presumably seconds (~2.3 hours)
        valid_frac=0.33,  # presumably fraction (33%) held out for validation
        train_stride=0.0625,
        inference_sampling_rate=inference_sampling_rate,
        start_offset=0,
        filt_order=8
    )

    # Initialize the model
    model = DeepClean(
        arch=Architecture(),  # TODO: initialize with the real architecture arguments
        loss=PsdRatio(sample_rate=sample_rate, fftlength=4, freq_low=[55], freq_high=[65]),
        metric=OnlinePsdRatio(
            inference_sampling_rate=inference_sampling_rate,
            edge_pad=0.25,
            filter_pad=0.5,
            sample_rate=sample_rate,
            bandpass=lambda x: x,  # Dummy, replace with your bandpass function
            y_scaler=_IdentityScaler(),  # Dummy, replace with your scaler
        ),
        patience=20,
        save_top_k_models=3
    )

    # NOTE(review): DeepClean.configure_callbacks also returns a PsdPlotter,
    # a ModelCheckpoint and (with patience set) an EarlyStopping. Lightning
    # lets model-configured callbacks replace Trainer callbacks of the same
    # type, so these settings may be overridden — confirm.
    callbacks = [
        PsdPlotter(),
        ModelCheckpoint(monitor='val_loss', save_top_k=3, mode='min'),
        EarlyStopping(monitor='val_loss', patience=20, mode='min')
    ]

    logger = WandbLogger(project='DeepCleanProject', log_model='all')

    trainer = Trainer(
        max_epochs=100,
        callbacks=callbacks,
        logger=logger,
    )

    # Train the model
    trainer.fit(model, datamodule=data_module)

    # Test the model
    trainer.test(model, datamodule=data_module)


if __name__ == '__main__':
    train_model()

In [None]:
import os
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from train.model import DeepClean
from train.data import DeepCleanDataset
from train.callbacks import PsdPlotter
from train.metrics import OnlinePsdRatio, PsdRatio
from train.architectures import Architecture  # Ensure this is correctly defined and imported

# Correcting the import for SpectralDensity and Metric if needed
from ml4gw.transforms import SpectralDensity
from torchmetrics import Metric

# Seed Python, NumPy and torch RNGs once so runs are reproducible.
seed_everything(seed=101588)


class PsdRatio(torch.nn.Module):
    """
    Loss: mean ratio of the cleaned-strain PSD to the raw-strain PSD over
    the frequency bins inside the requested bands.

    Args:
        sample_rate: data sample rate in Hz.
        fftlength: FFT segment length in seconds.
        freq_low: lower edge (Hz) of each band, paired with ``freq_high``.
        freq_high: upper edge (Hz) of each band; the bin at this edge is
            included.
        overlap: segment overlap forwarded to ``SpectralDensity``.
        asd: if True, take the square root of the ratio (an ASD ratio)
            before averaging.
    """

    def __init__(
        self,
        sample_rate: float,
        fftlength: float,
        freq_low: list[float],
        freq_high: list[float],
        overlap: Optional[float] = None,
        asd: bool = False,
    ) -> None:
        super().__init__()
        self.spectral_density = SpectralDensity(
            sample_rate,
            fftlength,
            overlap=overlap,
            average="median",
            fast=True,
        )
        self.asd = asd
        self.sample_rate = sample_rate

        # one boolean per one-sided frequency bin; bins are 1 / fftlength Hz
        # apart, so frequency f lands in bin int(f * fftlength)
        num_bins = int(fftlength * sample_rate / 2) + 1
        band_mask = torch.zeros((num_bins,), dtype=torch.bool)
        for lo_hz, hi_hz in zip(freq_low, freq_high):
            first_bin = int(lo_hz * fftlength)
            last_bin = int(hi_hz * fftlength)
            band_mask[first_bin : last_bin + 1] = True
        self.register_buffer("mask", band_mask)

    def forward(self, pred, strain):
        """Return the per-row mean in-band PSD (or ASD) ratio of ``strain - pred`` to ``strain``."""
        residual_psd = self.spectral_density((strain - pred).double())
        strain_psd = self.spectral_density(strain.double())

        in_band = (residual_psd / strain_psd)[:, self.mask]
        if self.asd:
            in_band = in_band**0.5
        return in_band.mean(dim=-1)


class OnlinePsdRatio(Metric):
    """
    Torchmetrics metric that reassembles overlapping per-window noise
    predictions into one timeseries, cleans the accumulated strain in
    1-second frames ("online"), and scores the result with ``loss_fn``.

    Args:
        inference_sampling_rate: rate (Hz) at which prediction windows were
            produced; consecutive windows are treated as
            ``sample_rate / inference_sampling_rate`` samples apart.
        edge_pad: seconds kept away from the window edge when taking each
            window's new stride of predictions.
        filter_pad: seconds of padding on each side of a frame, letting the
            bandpass filter settle; it is sliced off after filtering.
        sample_rate: data sample rate in Hz; frames are ``sample_rate``
            samples long.
        bandpass: callable applied to a numpy array of noise frames.
        y_scaler: module called as ``y_scaler(x, reverse=True)`` to undo
            target scaling.

    ``loss_fn`` must be assigned (DeepClean does this) before ``update``.
    """

    def __init__(
        self,
        inference_sampling_rate: float,
        edge_pad: float,
        filter_pad: float,
        sample_rate: float,
        bandpass: Callable,
        y_scaler: torch.nn.Module,
    ) -> None:
        super().__init__()
        # samples between the starts of consecutive prediction windows
        self.stride = int(sample_rate / inference_sampling_rate)
        # paddings converted from seconds to samples
        self.filter_pad = int(filter_pad * sample_rate)
        self.edge_pad = int(edge_pad * sample_rate)
        self.sample_rate = sample_rate

        # assigned externally before any update
        self.loss_fn = None
        self.bandpass = bandpass
        self.y_scaler = y_scaler

        # list states: each update appends one batch tensor
        self.add_state("predictions", default=[])
        self.add_state("strain", default=[])

    def update(self, y, kind):
        """
        Append the batch tensor ``y`` to the ``kind`` state list.

        Args:
            y: batch of strain or of noise predictions.
            kind: ``"predictions"`` or ``"strain"``.

        Raises:
            ValueError: if ``loss_fn`` has not been set, or ``kind`` is not
                one of the two state names.
        """
        if self.loss_fn is None:
            raise ValueError("Must provide loss_fn before calling update")
        if kind not in ("predictions", "strain"):
            raise ValueError(
                f"kind must be 'predictions' or 'strain', got {kind!r}"
            )
        getattr(self, kind).append(y)

    def clean(self):
        """
        Stitch the accumulated predictions into a noise timeseries and
        return it with the matching raw strain.

        Returns:
            ``(noise, raw)``, each reshaped to ``(1, -1)``. ``noise`` is the
            unscaled, bandpassed noise prediction for interior frames 1 to
            ``num_frames - 2``. ``raw`` holds rows ``1 : num_frames - 1`` of
            the concatenated strain (assumes one strain row per frame — TODO
            confirm).

        Raises:
            ValueError: if nothing has been accumulated, or there is too
                little data to form at least one interior frame.
        """
        if not self.predictions or not self.strain:
            raise ValueError(
                "OnlinePsdRatio.clean called before any predictions and "
                "strain were accumulated via update"
            )

        # first build our overlapping predictions
        # into a single timeseries of noise predictions
        size = sum([i.numel() for i in self.strain])
        batch_size = len(self.predictions[0])
        device = self.predictions[0].device
        dtype = self.predictions[0].dtype

        y_pred = torch.zeros(
            (size - self.edge_pad,), device=device, dtype=dtype
        )

        # for each predicted window, only slice out
        # the single stride of new data from it that's
        # sufficiently far from the edge to be considered
        # "safe" and place it in the corresponding position
        # in the full timeseries. We can do this array-style
        # with some fancy indexing. We'll start by building
        # the array of indices we'll grab from the predicted batches
        get_idx = torch.arange(self.stride, device=device)
        offset = int(self.sample_rate) - self.edge_pad - self.stride
        get_idx += offset

        # then turn this into a matrix of indices where
        # each row of predictions will go in the timeseries
        set_idx = get_idx.view(1, -1).repeat(batch_size, 1)
        batch_offset = torch.arange(batch_size, device=device)
        set_idx += batch_offset[:, None] * self.stride

        for i, y in enumerate(self.predictions):
            sidx = set_idx[: len(y)]
            y_pred[sidx + i * batch_size * self.stride] = y[:, get_idx]

            # for the very first frame, we have no choice
            # but to fill the left side with our predictions.
            # This won't really matter since we don't end up
            # measuring ourselves on this frame, but we'll need
            # it for providing filter padding.
            if not i:
                y_pred[:offset] = y[0, :offset]

        # now clean the target timeseries in the
        # online fashion, one frame at a time, plus
        # some filter settle-in padding on each side.
        # Ignore the first and last frames to account
        # for this filter settle-in.
        num_frames = int((len(y_pred) - self.filter_pad) // self.sample_rate)
        noise = []
        for i in range(1, num_frames - 1):
            start = int(i * self.sample_rate) - self.filter_pad
            stop = int((i + 1) * self.sample_rate) + self.filter_pad
            noise.append(y_pred[start:stop])
        if not noise:
            raise ValueError(
                "Not enough accumulated data to clean: need at least 3 "
                f"frames of {int(self.sample_rate)} samples plus filter "
                f"padding, got {num_frames} frame(s)"
            )

        # postprocess, doing the bandpass filtering back
        # in numpy because torchaudio won't work
        noise = torch.stack(noise)
        noise = self.y_scaler(noise, reverse=True)
        noise = self.bandpass(noise.cpu().numpy())
        noise = torch.tensor(noise, device=device)

        # slice out the filter padding so that the
        # frames in each row are no longer overlapping,
        # then reshape them to a proper timeseries. Use an
        # explicit stop index: with filter_pad == 0, a
        # ``-filter_pad`` stop would be ``-0`` and select nothing.
        stop = noise.shape[-1] - self.filter_pad
        noise = noise[:, self.filter_pad : stop]
        noise = noise.reshape(1, -1)

        # reshape our raw strain into a timeseries
        raw = torch.cat(self.strain, dim=0)[1 : num_frames - 1]
        raw = raw.view(1, -1)
        return noise, raw

    def compute(self, reduce: bool = True):
        """Return the mean loss if ``reduce`` else the ``(noise, raw)`` pair."""
        noise, raw = self.clean()
        if reduce:
            return self.loss_fn(noise, raw).mean()
        return noise, raw


def train_model():
    """
    Train DeepClean on the K1 data file below, then run the test loop.

    Uses the ``PsdRatio`` / ``OnlinePsdRatio`` classes defined in this cell.
    After ``fit``, ``test`` is called without a model, so Lightning
    evaluates the best checkpoint from the fit.

    NOTE(review): the DeepClean copy in this notebook defines no
    ``configure_optimizers``. If ``train.model.DeepClean`` doesn't either
    (e.g. it relies on LightningCLI to inject the optimizer), ``fit`` will
    fail here — confirm.
    """
    # DeepClean is a ``lightning.pytorch`` LightningModule, so the Trainer,
    # logger and stock callbacks must come from ``lightning.pytorch`` too.
    # The standalone ``pytorch_lightning`` package defines distinct classes,
    # and its Trainer rejects the model with a TypeError.
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
    from lightning.pytorch.loggers import WandbLogger

    class _IdentityScaler(torch.nn.Module):
        """
        Placeholder target scaler that returns its input unchanged.

        ``OnlinePsdRatio.clean`` calls ``y_scaler(x, reverse=True)``. A bare
        ``torch.nn.Module()`` has no ``forward`` and raises
        NotImplementedError when called.
        """

        def forward(self, x, reverse=False):
            return x

    def _identity_bandpass(x):
        # Placeholder bandpass. ``OnlinePsdRatio.clean`` calls
        # ``bandpass(array)``, so this must be a callable, not a list of
        # band edges.
        return x

    sample_rate = 4096  # Hz
    # The metric stitches predictions assuming consecutive windows are
    # ``sample_rate / inference_sampling_rate`` samples apart, so it must use
    # the same inference rate as the data module.
    inference_sampling_rate = 64  # Hz

    data_module = DeepCleanDataset(
        fname='/home/shuwei.yeh/deepclean/data/K-K1_lldata-1369291863-12288.hdf5',
        channels=['K1:CAL-CS_PROC_DARM_STRAIN_DBL_DQ', 'K1:PEM-MIC_OMC_BOOTH_OMC_Z_OUT_DQ'],
        kernel_length=0.25,
        freq_low=[55],
        freq_high=[65],
        batch_size=32,
        train_duration=4096,  # presumably seconds (~68 minutes)
        test_duration=8192,  # presumably seconds (~2.3 hours)
        valid_frac=0.33,
        train_stride=0.0625,
        inference_sampling_rate=inference_sampling_rate,
        start_offset=0,
        filt_order=8
    )

    model = DeepClean(
        arch=Architecture(),  # TODO: initialize with the real architecture arguments
        loss=PsdRatio(sample_rate=sample_rate, fftlength=4, freq_low=[55], freq_high=[65]),
        metric=OnlinePsdRatio(
            inference_sampling_rate=inference_sampling_rate,
            edge_pad=0.25,
            filter_pad=0.5,
            sample_rate=sample_rate,
            bandpass=_identity_bandpass,  # TODO: replace with the real bandpass filter
            y_scaler=_IdentityScaler(),  # TODO: replace with the real scaler
        ),
        patience=20,
        save_top_k_models=3
    )

    # NOTE(review): DeepClean.configure_callbacks also returns a PsdPlotter,
    # a ModelCheckpoint subclass and (with patience set) an EarlyStopping.
    # Lightning lets model-configured callbacks replace Trainer callbacks of
    # the same or a parent type, so the settings below (including dirpath)
    # may be overridden — confirm. Also note that dirpath points into an
    # existing run's checkpoint directory.
    callbacks = [
        PsdPlotter(),
        ModelCheckpoint(dirpath='/home/shuwei.yeh/deepclean/results/K1_train_test/lightning_logs/version_34/checkpoints/', filename='{epoch}-{val_loss:.2f}', monitor='val_loss', save_top_k=3, mode='min'),
        EarlyStopping(monitor='val_loss', patience=20, mode='min')
    ]

    logger = WandbLogger(project='DeepCleanProject', log_model='all')

    trainer = Trainer(
        max_epochs=100,
        callbacks=callbacks,
        logger=logger,
    )

    # Train the model
    trainer.fit(model, datamodule=data_module)

    # Test the model
    trainer.test(datamodule=data_module)


if __name__ == '__main__':
    train_model()