In [1]:
from __future__ import annotations

import json
import logging
import os
import pathlib
from typing import Callable, ClassVar, Mapping

import pandas as pd
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as lightning
from torchmetrics import Accuracy
from torch.nn import functional as F

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

import vak

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FrameClassificationModel(lightning.LightningModule):
    """Lightning module wrapping a frame classification network.

    ``training_step`` computes ``loss(network(x), y)`` on each batch and logs
    it as "train_loss". ``validation_step`` runs the network on one sample at
    a time and logs the metrics given in ``metrics``.

    Parameters
    ----------
    labelmap : Mapping
        Mapping whose keys are label strings (values presumably the integer
        class for each label). Used to convert predicted frame labels back to
        label strings for the edit-distance metrics.
    network : torch.nn.Module, optional
        Network applied to the input frames. NOTE(review): the type hint also
        allows ``dict``, but this class does not build a network from one;
        whatever is passed in is called directly.
    loss : torch.nn.Module or Callable, optional
        Loss used by ``training_step``. Defaults to
        ``torch.nn.CrossEntropyLoss()`` when None.
    metrics : dict, optional
        Maps metric names to callables evaluated in ``validation_step``.
        Only "loss", "acc", "levenshtein" and "character_error_rate" are
        logged; other entries are ignored. Defaults to an empty dict (nothing
        logged) when None.
    post_tfm : Callable, optional
        Transform applied to predicted frame labels (as a numpy array). When
        given, "_tfm" variants of "acc", "levenshtein" and
        "character_error_rate" are also logged.
    optimizer : torch.optim.Optimizer, optional
        Optimizer returned by ``configure_optimizers``. When None, an Adam
        optimizer with learning rate ``DEFAULT_LR`` over ``self.parameters()``
        is created in ``__init__`` (which raises if the model has no
        parameters, e.g. when ``network`` is None).
    """

    # Learning rate of the default Adam optimizer.
    DEFAULT_LR = 0.003

    def __init__(
        self,
        labelmap: Mapping,
        network: torch.nn.Module | dict | None = None,
        loss: torch.nn.Module | Callable | None = None,
        metrics: dict | None = None,
        post_tfm: Callable | None = None,
        optimizer: torch.optim.Optimizer | None = None,
    ):
        super().__init__()

        self.network = network
        # default loss so ``training_step`` never tries to call ``None``
        self.loss = loss if loss is not None else torch.nn.CrossEntropyLoss()
        # default to "no metrics" so ``validation_step`` can iterate unconditionally
        self.metrics = metrics if metrics is not None else {}
        # Build the optimizer exactly once, *after* ``self.network`` is assigned
        # so that ``self.parameters()`` includes the network's weights.
        # ``configure_optimizers`` hands this same instance to Lightning, so
        # ``self.optimizer`` is the optimizer that is actually used for training.
        if optimizer is None:
            optimizer = torch.optim.Adam(lr=self.DEFAULT_LR, params=self.parameters())
        self.optimizer = optimizer

        self.labelmap = labelmap
        # replace any multiple character labels in mapping
        # with single-character labels
        # so that we do not affect edit distance computation
        # see https://github.com/NickleDave/vak/issues/373
        labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != "unlabeled"]
        if any(len(label) > 1 for label in labelmap_keys):
            # only re-map if necessary (to minimize chance of knock-on bugs)
            self.eval_labelmap = vak.common.labels.multi_char_labels_to_single_char(
                labelmap
            )
        else:
            self.eval_labelmap = labelmap

        # presumably converts integer frame-label arrays to label strings
        # for the edit-distance metrics
        self.to_labels_eval = vak.transforms.frame_labels.ToLabels(
            self.eval_labelmap
        )
        self.post_tfm = post_tfm

    def configure_optimizers(self):
        """Return the model's optimizer.

        Method required by ``lightning.LightningModule``. Returns the
        ``optimizer`` instance passed into ``__init__``, or the default Adam
        optimizer created there when None was passed in.
        """
        return self.optimizer

    def training_step(self, batch: tuple | Mapping, batch_idx: int):
        """Compute and log the training loss for one batch.

        ``batch`` may be a sequence ``(x, y, ...)`` or a mapping with keys
        "frames" and "frame_labels" (the format ``validation_step`` uses).
        Returns the loss tensor for Lightning to backpropagate.
        """
        if isinstance(batch, Mapping):
            x, y = batch["frames"], batch["frame_labels"]
        else:
            x, y = batch[0], batch[1]
        out = self.network(x)
        loss = self.loss(out, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch: tuple, batch_idx: int):
        """Run the network on one validation sample and log metrics.

        ``batch`` is indexed as a mapping with keys "frames" and
        "frame_labels", and optionally "padding_mask" (True where a frame is
        data, False where it is padding).
        NOTE(review): the reshaping below assumes the network output is
        (windows, classes, frames per window); confirm against the network.

        Raises
        ------
        ValueError
            If "frames" is not 4- or 5-dimensional, or "padding_mask" is not
            2-dimensional.
        """
        x, y = batch["frames"], batch["frame_labels"]
        # remove "batch" dimension added by collate_fn to x
        # we keep for y because loss still expects the first dimension to be batch
        # TODO: fix this weirdness. Diff't collate_fn?
        if x.ndim in (5, 4):
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        else:
            raise ValueError(f"invalid shape for x: {x.shape}")

        out = self.network(x)
        # permute and flatten out
        # so that it has shape (1, number classes, number of time bins)
        # ** NOTICE ** just calling out.reshape(1, out.shape(1), -1) does not work, it will change the data
        out = out.permute(1, 0, 2)
        out = torch.flatten(out, start_dim=1)
        out = torch.unsqueeze(out, dim=0)
        # reduce to predictions, assuming class dimension is 1;
        # y_pred has dims (batch size 1, predicted label per time bin)
        y_pred = torch.argmax(out, dim=1)

        if "padding_mask" in batch:
            # boolean: 1 where valid, 0 where padding
            padding_mask = batch["padding_mask"]
            # remove "batch" dimension added by collate_fn
            # because this extra dimension just makes it confusing to use the mask as indices
            if padding_mask.ndim == 2:
                if padding_mask.shape[0] == 1:
                    padding_mask = torch.squeeze(padding_mask, dim=0)
            else:
                raise ValueError(
                    f"invalid shape for padding mask: {padding_mask.shape}"
                )

            # drop padded frames from the outputs and predictions
            # (assumes y itself was not padded)
            out = out[:, :, padding_mask]
            y_pred = y_pred[:, padding_mask]

        y_labels = self.to_labels_eval(y.cpu().numpy())
        y_pred_labels = self.to_labels_eval(y_pred.cpu().numpy())

        has_post_tfm = bool(self.post_tfm)
        if has_post_tfm:
            y_pred_tfm = self.post_tfm(
                y_pred.cpu().numpy(),
            )
            y_pred_tfm_labels = self.to_labels_eval(y_pred_tfm)
            # convert back to tensor so we can compute accuracy
            y_pred_tfm = torch.from_numpy(y_pred_tfm).to(self.device)

        # TODO: figure out smarter way to do this
        for metric_name, metric_callable in self.metrics.items():
            if metric_name == "loss":
                self._log_val_metric(f"val_{metric_name}", metric_callable(out, y))
            elif metric_name == "acc":
                self._log_val_metric(f"val_{metric_name}", metric_callable(y_pred, y))
                if has_post_tfm:
                    self._log_val_metric(
                        f"val_{metric_name}_tfm", metric_callable(y_pred_tfm, y)
                    )
            elif metric_name in ("levenshtein", "character_error_rate"):
                self._log_val_metric(
                    f"val_{metric_name}", metric_callable(y_pred_labels, y_labels)
                )
                if has_post_tfm:
                    self._log_val_metric(
                        f"val_{metric_name}_tfm",
                        metric_callable(y_pred_tfm_labels, y_labels),
                    )

    def _log_val_metric(self, name: str, value) -> None:
        """Log one per-sample validation value with the settings shared by every val metric."""
        self.log(
            name,
            value,
            batch_size=1,
            on_step=True,
            sync_dist=True,
        )

In [3]:
# Directory of a generated frame-classification dataset (TweetyNet test data).
dataset_path = (
    './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/'
    '032312-vak-frame-classification-dataset-generated-231010_165438/'
)

In [4]:
# Load the dataset's metadata to find the name of its CSV file, then read that CSV.
dataset_path = pathlib.Path(dataset_path)
metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path)
dataset_csv_path = dataset_path.joinpath(metadata.dataset_csv_filename)
dataset_df = pd.read_csv(dataset_csv_path)

In [5]:
# Label map stored alongside the dataset as JSON.
labelmap_path = dataset_path.joinpath("labelmap.json")
labelmap = json.loads(labelmap_path.read_text())

In [6]:
# Fit the spectrogram standardizer on the training split only; the same
# fitted object is passed to both the train and the eval transforms.
spect_standardizer = vak.transforms.StandardizeSpect.fit_dataset_path(dataset_path, split="train")

In [7]:
train_transform_params = {"spect_standardizer": spect_standardizer}
# In "train" mode the default transform comes back as a
# (transform, target_transform) pair.
transform, target_transform = vak.transforms.defaults.get_default_transform(
    "TweetyNet",
    "train",
    transform_kwargs=train_transform_params,
)

In [8]:
# Training data: windows of size 44 (presumably time bins) from the "train" split.
train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path(
    dataset_path=dataset_path,
    split="train",
    window_size=44,
    transform=transform,
    target_transform=target_transform,
)

In [9]:
# Shuffled mini-batches of 16 training windows, loaded by 16 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=16,
)


In [10]:
# Eval transform: same window size (44) as ``train_dataset`` and the
# standardizer fit on the training split.
val_transform_params = {
    "window_size": 44,
    "spect_standardizer": spect_standardizer,
}
item_transform = vak.transforms.defaults.get_default_transform(
    "TweetyNet", "eval", transform_kwargs=val_transform_params
)

In [11]:
# Validation data: samples from the "val" split, each passed through ``item_transform``.
val_dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path(
    split="val",
    dataset_path=dataset_path,
    item_transform=item_transform,
)

In [12]:
# One item per batch: each spectrogram is already reshaped into a batch of
# windows by ``item_transform``, so batching several items is not wanted here.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=16,
)

In [13]:
# Seed Python, NumPy and torch RNGs before building the network so weight
# initialization is repeatable; the DataLoaders draw their shuffling/worker
# seeds from the global torch RNG when iterated, so those repeat too.
SEED = 42
lightning.seed_everything(SEED)

# NOTE(review): assumes ``train_dataset.shape`` is (channels, freq bins, ...) — confirm.
network = vak.nets.TweetyNet(
    num_classes=len(labelmap),
    num_freqbins=train_dataset.shape[1],
    num_input_channels=train_dataset.shape[0],
)

In [14]:
# Training loss; "mean" is CrossEntropyLoss's default reduction, spelled out for clarity.
loss = torch.nn.CrossEntropyLoss(reduction="mean")

In [15]:
# Validation metrics, keyed by the names FrameClassificationModel.validation_step looks for.
metrics = dict(
    acc=vak.metrics.Accuracy(),
    levenshtein=vak.metrics.Levenshtein(),
    character_error_rate=vak.metrics.CharacterErrorRate(),
    loss=torch.nn.CrossEntropyLoss(),
)

In [16]:
# Arguments listed in the same order as FrameClassificationModel's signature.
model = FrameClassificationModel(
    labelmap=labelmap,
    network=network,
    loss=loss,
    metrics=metrics,
)

In [17]:
# Periodic checkpoint every 10 training steps, plus a "last" checkpoint;
# both are named "checkpoint" and use the ".pt" extension.
ckpt_callback = lightning.callbacks.ModelCheckpoint(
    dirpath='.',
    filename="checkpoint",
    save_last=True,
    every_n_train_steps=10,
    verbose=True,
)
ckpt_callback.FILE_EXTENSION = ".pt"
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"

# Keep only the single checkpoint with the highest "val_acc", named "best".
val_ckpt_callback = lightning.callbacks.ModelCheckpoint(
    dirpath='.',
    filename="best",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    auto_insert_metric_name=False,
    verbose=True,
)
val_ckpt_callback.FILE_EXTENSION = ".pt"

# Stop once "val_acc" has not improved for 4 validation checks in a row.
early_stopping = lightning.callbacks.EarlyStopping(
    monitor="val_acc",
    mode="max",
    patience=4,
    verbose=True,
)

callbacks = [
    ckpt_callback,
    val_ckpt_callback,
    early_stopping,
]

In [18]:
# TensorBoard event files are written under the current directory.
logger = lightning.loggers.TensorBoardLogger(save_dir=".")

In [19]:
# Validate every 100 training steps and stop after at most 5000 steps.
# Use the GPU when CUDA is available; otherwise fall back to the CPU rather
# than requesting the "cuda" accelerator, which Lightning rejects without CUDA.
trainer = lightning.Trainer(
    callbacks=callbacks,
    val_check_interval=100,
    max_steps=5000,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    logger=logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [20]:
# Train with periodic validation; the callbacks handle checkpointing and early stopping.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Missing logger folder: ./lightning_logs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | network | TweetyNet        | 319 K 
1 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
319 K     Trainable params
0         Non-trainable params
319 K     Total params
1.277     Total estimated model params size (MB)


                                                                                                                                                      

  return torch.tensor(rate, dtype=torch.float32)


Epoch 0:   6%|█████▉                                                                                      | 100/1560 [00:02<00:35, 41.15it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 30.39it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 34.49it/s][A

Metric val_acc improved. New best score: 0.886


Epoch 0:   6%|█████▉                                                                                      | 100/1560 [00:03<00:44, 32.88it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 100: 'val_acc' reached 0.88571 (best 0.88571), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  13%|███████████▊                                                                                | 200/1560 [00:04<00:32, 42.04it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 38.22it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 43.16it/s][A

Metric val_acc improved by 0.063 >= min_delta = 0.0. New best score: 0.949


Epoch 0:  13%|███████████▊                                                                                | 200/1560 [00:05<00:36, 37.46it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 200: 'val_acc' reached 0.94883 (best 0.94883), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  19%|█████████████████▋                                                                          | 300/1560 [00:06<00:29, 43.03it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 41.72it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 46.22it/s][A

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.954


Epoch 0:  19%|█████████████████▋                                                                          | 300/1560 [00:07<00:31, 39.75it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 300: 'val_acc' reached 0.95399 (best 0.95399), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  26%|███████████████████████▌                                                                    | 400/1560 [00:09<00:26, 43.17it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 44.00it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 48.99it/s][A

Metric val_acc improved by 0.012 >= min_delta = 0.0. New best score: 0.966


Epoch 0:  26%|███████████████████████▌                                                                    | 400/1560 [00:09<00:28, 40.60it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 400: 'val_acc' reached 0.96578 (best 0.96578), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  32%|█████████████████████████████▍                                                              | 500/1560 [00:11<00:24, 43.16it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 44.04it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 49.80it/s][A

Metric val_acc improved by 0.007 >= min_delta = 0.0. New best score: 0.973


Epoch 0:  32%|█████████████████████████████▍                                                              | 500/1560 [00:12<00:25, 41.07it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 500: 'val_acc' reached 0.97262 (best 0.97262), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  38%|███████████████████████████████████▍                                                        | 600/1560 [00:13<00:22, 43.42it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 46.28it/s][A
Epoch 0:  38%|███████████████████████████████████▍                                                        | 600/1560 [00:14<00:23, 41.63it/s, v_num=0][A
                                                                                                                                                      [A

Epoch 0, global step 600: 'val_acc' was not in top 1


Epoch 0:  45%|█████████████████████████████████████████▎                                                  | 700/1560 [00:16<00:19, 43.43it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 46.69it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 52.04it/s][A

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.978


Epoch 0:  45%|█████████████████████████████████████████▎                                                  | 700/1560 [00:16<00:20, 41.93it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 700: 'val_acc' reached 0.97803 (best 0.97803), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  51%|███████████████████████████████████████████████▏                                            | 800/1560 [00:18<00:17, 43.47it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 45.36it/s][A
Epoch 0:  51%|███████████████████████████████████████████████▏                                            | 800/1560 [00:18<00:18, 42.12it/s, v_num=0][A
                                                                                                                                                      [A

Epoch 0, global step 800: 'val_acc' was not in top 1


Epoch 0:  58%|█████████████████████████████████████████████████████                                       | 900/1560 [00:20<00:15, 43.62it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 46.16it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 48.92it/s][A

Metric val_acc improved by 0.002 >= min_delta = 0.0. New best score: 0.980


Epoch 0:  58%|█████████████████████████████████████████████████████                                       | 900/1560 [00:21<00:15, 42.38it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 900: 'val_acc' reached 0.97985 (best 0.97985), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  64%|██████████████████████████████████████████████████████████▎                                | 1000/1560 [00:22<00:12, 43.66it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 46.13it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.36it/s][A

Metric val_acc improved by 0.004 >= min_delta = 0.0. New best score: 0.983


Epoch 0:  64%|██████████████████████████████████████████████████████████▎                                | 1000/1560 [00:23<00:13, 41.93it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 1000: 'val_acc' reached 0.98347 (best 0.98347), saving model to '/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/best.pt' as top 1


Epoch 0:  71%|████████████████████████████████████████████████████████████████▏                          | 1100/1560 [00:25<00:10, 43.11it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 48.65it/s][A
Epoch 0:  71%|████████████████████████████████████████████████████████████████▏                          | 1100/1560 [00:26<00:10, 42.17it/s, v_num=0][A
                                                                                                                                                      [A

Epoch 0, global step 1100: 'val_acc' was not in top 1


Epoch 0:  77%|██████████████████████████████████████████████████████████████████████                     | 1200/1560 [00:27<00:08, 43.30it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 45.22it/s][A
Epoch 0:  77%|██████████████████████████████████████████████████████████████████████                     | 1200/1560 [00:28<00:08, 42.43it/s, v_num=0][A
                                                                                                                                                      [A

Epoch 0, global step 1200: 'val_acc' was not in top 1


Epoch 0:  83%|███████████████████████████████████████████████████████████████████████████▊               | 1300/1560 [00:29<00:05, 43.39it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 49.26it/s][A
Epoch 0:  83%|███████████████████████████████████████████████████████████████████████████▊               | 1300/1560 [00:30<00:06, 42.60it/s, v_num=0][A
                                                                                                                                                      [A

Epoch 0, global step 1300: 'val_acc' was not in top 1


Epoch 0:  90%|█████████████████████████████████████████████████████████████████████████████████▋         | 1400/1560 [00:32<00:03, 43.51it/s, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                  | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 46.85it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 51.69it/s][A

Monitored metric val_acc did not improve in the last 4 records. Best score: 0.983. Signaling Trainer to stop.


Epoch 0:  90%|█████████████████████████████████████████████████████████████████████████████████▋         | 1400/1560 [00:32<00:03, 42.75it/s, v_num=0]
                                                                                                                                                      [A

Epoch 0, global step 1400: 'val_acc' was not in top 1


Epoch 0:  90%|█████████████████████████████████████████████████████████████████████████████████▋         | 1400/1560 [00:32<00:03, 42.74it/s, v_num=0]
