In [None]:
from __future__ import annotations

import datetime
import logging
import pathlib
from typing import Callable

import numpy as np
import numpy.typing as npt
import pandas as pd
import pytorch_lightning as lightning
import torch.utils.data

from src import vak

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_split_dur(df: pd.DataFrame, split: str) -> float:
    """Sum the ``duration`` column over the rows that belong to one split.

    Parameters
    ----------
    df : pandas.DataFrame
        Table representing a dataset. Must have a ``split`` column
        and a ``duration`` column.
    split : str
        Value of the ``split`` column that selects the rows to sum.

    Returns
    -------
    float
        Sum of ``duration`` over the selected rows.
    """
    in_split = df["split"] == split
    return df.loc[in_split, "duration"].sum()

In [None]:
def get_trainer(
    max_epochs: int,
    ckpt_root: str | pathlib.Path,
    ckpt_step: int,
    log_save_dir: str | pathlib.Path,
    device: str = "cuda",
) -> lightning.Trainer:
    """Build a ``lightning.Trainer`` with two checkpoint callbacks
    and a TensorBoard logger.

    Parameters
    ----------
    max_epochs : int
        Passed to ``lightning.Trainer`` as ``max_epochs``.
    ckpt_root : str or pathlib.Path
        Directory handed to both checkpoint callbacks as ``dirpath``.
    ckpt_step : int
        Passed to the first checkpoint callback as ``every_n_train_steps``.
    log_save_dir : str or pathlib.Path
        Passed to ``TensorBoardLogger`` as ``save_dir``.
    device : str
        ``"cuda"`` selects the ``"gpu"`` accelerator;
        any other value selects ``"auto"``. Default is ``"cuda"``.

    Returns
    -------
    trainer : lightning.Trainer
        Trainer configured with the logger and both callbacks.
    """
    # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
    accelerator = "gpu" if device == "cuda" else "auto"

    # Step-based checkpoint, named "checkpoint" with a ".pt" extension;
    # the "last" checkpoint is given the same name.
    step_ckpt = lightning.callbacks.ModelCheckpoint(
        dirpath=ckpt_root,
        filename="checkpoint",
        every_n_train_steps=ckpt_step,
        save_last=True,
        verbose=True,
    )
    step_ckpt.CHECKPOINT_NAME_LAST = "checkpoint"
    step_ckpt.FILE_EXTENSION = ".pt"

    # Second checkpoint monitors "val_loss" (mode "min", top 1 only).
    best_val_ckpt = lightning.callbacks.ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=ckpt_root,
        filename="min-val-loss-checkpoint",
        auto_insert_metric_name=False,
        verbose=True,
    )
    best_val_ckpt.FILE_EXTENSION = ".pt"

    return lightning.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        logger=lightning.loggers.TensorBoardLogger(save_dir=log_save_dir),
        callbacks=[step_ckpt, best_val_ckpt],
    )

In [None]:
class SpectrogramPipe(torch.utils.data.Dataset):
    """Pipeline for loading samples from a dataset of spectrograms.

    This is a simplified version of ``vak.datasets.parametric_umap.ParametricUmapInferenceDataset``.

    Each item is a ``dict`` with two keys: ``"x"``, the sample at that
    position in ``data`` (passed through ``transform`` when one is set),
    and ``"df_index"``, the label at the same position in ``dataset_df.index``.

    Attributes
    ----------
    data : numpy.ndarray
        Samples, indexed along the first axis.
    dataset_df : pandas.DataFrame
        Table with one row per sample; must have a ``duration`` column.
    transform : callable, optional
        Applied to each sample in ``__getitem__`` when not ``None``.
    """

    def __init__(
        self,
        data: npt.NDArray,
        dataset_df: pd.DataFrame,
        transform: Callable | None = None,
    ):
        """Store the samples, their table, and the optional transform."""
        self.data = data
        self.dataset_df = dataset_df
        self.transform = transform

    @property
    def duration(self):
        """Sum of the ``duration`` column of ``dataset_df``."""
        return self.dataset_df["duration"].sum()

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of ``data``."""
        return self.data.shape[0]

    @property
    def shape(self):
        """Shape of ``"x"`` for the first item, after any transform is applied."""
        return self[0]["x"].shape

    def __getitem__(self, index):
        """Return ``{"x": sample, "df_index": row label}`` for position ``index``."""
        x = self.data[index]
        df_index = self.dataset_df.index[index]
        if self.transform:
            x = self.transform(x)
        return {"x": x, "df_index": df_index}

    @classmethod
    def from_dataset_path(
        cls,
        dataset_path: str | pathlib.Path,
        split: str,
        transform: Callable | None = None,
    ):
        """Build a ``SpectrogramPipe`` for one split of a dataset directory.

        Reads the dataset's metadata to find its CSV file, keeps the rows
        whose ``split`` column equals ``split``, loads every file listed in
        their ``spect_path`` column (relative to ``dataset_path``) with
        ``numpy.load``, and stacks the results into one array.

        Parameters
        ----------
        dataset_path : str or pathlib.Path
            Root directory of the dataset.
        split : str
            Value of the ``split`` column that selects rows.
        transform : callable, optional
            Passed through to the constructor.

        Returns
        -------
        SpectrogramPipe

        Raises
        ------
        ValueError
            If the dataset has no rows for ``split``.
        """
        import vak.datasets  # import here just to make classmethod more explicit

        dataset_path = pathlib.Path(dataset_path)
        metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(
            dataset_path
        )

        dataset_df = pd.read_csv(dataset_path / metadata.dataset_csv_filename)
        split_df = dataset_df[dataset_df["split"] == split]
        if split_df.empty:
            # Without this check ``np.stack`` fails on the empty list with
            # "need at least one array to stack", which does not say which
            # split was missing. Same exception type, clearer message.
            available = dataset_df["split"].unique().tolist()
            raise ValueError(
                f"No rows with split '{split}' found in dataset at: {dataset_path}. "
                f"Splits present in the dataset: {available}"
            )

        data = np.stack(
            [
                np.load(dataset_path / spect_path)
                for spect_path in split_df["spect_path"]
            ]
        )
        return cls(data, split_df, transform=transform)

In [None]:
# Directory of the prepared dataset used by the rest of the notebook.
# The path is relative, so it is resolved against the kernel's working directory.
dataset_path = pathlib.Path(
    "./tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/ConvEncoderUMAP/032312-vak-dimensionality-reduction-dataset-generated-231010_165846/"
)


In [None]:
# Read the dataset's metadata to get the name of its CSV file,
# then load that CSV into a DataFrame.
metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(dataset_path)
dataset_csv_path = dataset_path.joinpath(metadata.dataset_csv_filename)
dataset_df = pd.read_csv(dataset_csv_path)


In [None]:
# NOTE(review): ``val_step`` is not referenced by any other cell visible here,
# and ``get_trainer`` takes no matching argument -- confirm whether it is still needed.
val_step = 2_000

In [None]:
# Root directory for the results of this training run.
# The path is relative, so it is resolved against the kernel's working directory.
results_path = pathlib.Path(
    './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/AVA'
)
# ``parents=True`` so the cell also works when the intermediate directories
# do not exist yet; without it ``mkdir`` raises ``FileNotFoundError``
# on a missing parent. ``exist_ok=True`` keeps re-runs of the cell harmless.
results_path.mkdir(parents=True, exist_ok=True)

In [None]:
# ---------------- load training data  -----------------------------------------------------------------------------

# Report how much data the training split holds, summed from the dataset CSV.
train_dur = get_split_dur(dataset_df, "train")
print(f"Total duration of training split from dataset (in s): {train_dur}")

# Default "train" transform for "ConvEncoderUMAP"; no extra parameters are supplied.
train_transform_params = dict()
transform = vak.transforms.defaults.get_default_transform(
    "ConvEncoderUMAP", "train", train_transform_params
)

# No extra keyword arguments are passed to the dataset; the dict is kept as a hook.
train_dataset_params = dict()
train_dataset = SpectrogramPipe.from_dataset_path(
    dataset_path=dataset_path,
    split="train",
    transform=transform,
    **train_dataset_params,
)

# Shuffled batches of 64 for training.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=16
)

In [None]:
# ---------------- load validation set (if there is one) -----------------------------------------------------------

# NOTE(review): despite the header, the "val" split is loaded unconditionally here.

# Default "eval" transform for "ConvEncoderUMAP"; no extra parameters are supplied.
# This rebinds ``transform``; ``train_dataset`` keeps its own reference to the training transform.
val_transform_params = dict()
transform = vak.transforms.defaults.get_default_transform(
    "ConvEncoderUMAP", "eval", val_transform_params
)

# No extra keyword arguments are passed to the dataset; the dict is kept as a hook.
val_dataset_params = dict()
val_dataset = SpectrogramPipe.from_dataset_path(
    dataset_path=dataset_path,
    split="val",
    transform=transform,
    **val_dataset_params,
)
print(f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}")

# Batches of 64 in dataset order (no shuffling) for validation.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=64, shuffle=False, num_workers=16
)

In [None]:
device = vak.common.device.get_default()

# The model is given the shape of one (transformed) training sample as its input shape.
model = vak.models.get(
    "AVA",
    input_shape=train_dataset.shape,
    config={"network": {}, "optimizer": {"lr": 0.001}},
)

# Results layout: <results_path>/AVA for logs, <results_path>/AVA/checkpoints for checkpoints.
results_model_root = results_path / "AVA"
results_model_root.mkdir(exist_ok=True)
ckpt_root = results_model_root / "checkpoints"
ckpt_root.mkdir(exist_ok=True)

trainer = get_trainer(
    max_epochs=50,
    ckpt_root=ckpt_root,
    ckpt_step=250,
    log_save_dir=results_model_root,
    device=device,
)

In [None]:
# Train the model on ``train_loader``, with ``val_loader`` supplied for validation.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)