## Tutorial on training a HTS-AT model for audio classification on the ESC-50 Dataset

Referece: 

[HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection, ICASSP 2022](https://arxiv.org/abs/2202.00874)

Following the HTS-AT's paper, in this tutorial, we would show how to use the HST-AT in the training of the ESC-50 Dataset.

The [ESC-50 dataset](https://github.com/karolpiczak/ESC-50) is a labeled collection of 2000 environmental audio recordings suitable for benchmarking methods of environmental sound classification. The dataset consists of 5-second-long recordings organized into 50 semantical classes (with 40 examples per class) loosely arranged into 5 major categories

Before running this tutorial, please make sure that you install the below packages by following steps:

1. download [the codebase](https://github.com/RetroCirce/HTS-Audio-Transformer), and put this tutorial notebook inside the codebase folder.

2. In the github code folder:

    > pip install -r requirements.txt

3. We do not include the installation of PyTorch in the requirment, since different machines require different vereions of CUDA and Toolkits. So make sure you install the PyTorch from [the official guidance](https://pytorch.org/).

4. Install the 'SOX' and the 'ffmpeg', we recommend that you run this code in Linux inside the Conda environment. In that, you can install them by:

    > sudo apt install sox
    
    > conda install -c conda-forge ffmpeg


In [1]:
# import basic packages
import os
import numpy as np
import wget
import sys
import gdown
import zipfile
import librosa
# in the notebook, we only can use one GPU
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Load the model package
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import warnings

from utils import create_folder, dump_config, process_idc
import esc_config as config
from sed_model import SEDWrapper, Ensemble_SEDWrapper
from data_generator import ESC_Dataset
from model.htsat import HTSAT_Swin_Transformer



In [2]:
# Data Preparation
class data_prep(pl.LightningDataModule):
    def __init__(self, train_dataset, eval_dataset, device_num):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num

    def train_dataloader(self):
        train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
        train_loader = DataLoader(
            dataset = self.train_dataset,
            num_workers = 0,
            #batch_size = config.batch_size // self.device_num,
            shuffle = False,
            sampler = train_sampler
        )
        return train_loader
    def val_dataloader(self):
        eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
        eval_loader = DataLoader(
            dataset = self.eval_dataset,
            num_workers = 0,
            #batch_size = config.batch_size // self.device_num,
            shuffle = False,
            sampler = eval_sampler
        )
        return eval_loader
    def test_dataloader(self):
        test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
        test_loader = DataLoader(
            dataset = self.eval_dataset,
            num_workers = 0,
            #batch_size = config.batch_size // self.device_num,
            shuffle = False,
            sampler = test_sampler
        )
        return test_loader

In [3]:
# Set the workspace
device_num = 1 #torch.cuda.device_count()
print("each batch size:", config.batch_size // device_num)

full_dataset = np.load(os.path.join("esc-50/esc-50-data-custom.npy"), allow_pickle = True)

# set exp folder
exp_dir = os.path.join(config.workspace, "results", config.exp_name)
checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
if not config.debug:
    create_folder(os.path.join(config.workspace, "results"))
    create_folder(exp_dir)
    create_folder(checkpoint_dir)
    dump_config(config, os.path.join(exp_dir, config.exp_name), False)
    
print("Using ESC")
dataset = ESC_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = False
)
eval_dataset = ESC_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = True
)

audioset_data = data_prep(dataset, eval_dataset, device_num)
checkpoint_callback = ModelCheckpoint(
    monitor = "acc",
    filename='l-{epoch:d}-{acc:.3f}',
    save_top_k = 20,
    mode = "max"
)

each batch size: 64
Using ESC
total dataset size: 1600
total dataset size: 400


In [4]:
# Set the Trainer
trainer = pl.Trainer(
    deterministic=False,
    default_root_dir = checkpoint_dir,
    val_check_interval = 1.0,
    max_epochs = config.max_epoch,
    auto_lr_find = True,    
    sync_batchnorm = True,
    callbacks = [checkpoint_callback],
    accelerator = "ddp" if device_num > 1 else None,
    devices=1,
    num_sanity_val_steps = 0,
    resume_from_checkpoint = None, 
    replace_sampler_ddp = False,
    gradient_clip_val=1.0
)

sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config = config,
    depths = config.htsat_depth,
    embed_dim = config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head
)

model = SEDWrapper(
    sed_model = sed_model, 
    config = config,
    dataset = dataset
)

if config.resume_checkpoint is not None:
    print("Load Checkpoint from ", config.resume_checkpoint)
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    ckpt["state_dict"].pop("sed_model.head.weight")
    ckpt["state_dict"].pop("sed_model.head.bias")
    # finetune on the esc and spv2 dataset
    ckpt["state_dict"].pop("sed_model.tscam_conv.weight")
    ckpt["state_dict"].pop("sed_model.tscam_conv.bias")
    model.load_state_dict(ckpt["state_dict"], strict=False)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Load Checkpoint from  ./esc-50/HTSAT_ESC_exp=1_fold=0_acc=0.970.ckpt


In [5]:
## Training the model
# You can set different fold index by setting 'esc_fold' to any number from 0-4 in esc_config.py
trainer.fit(model, audioset_data)


  | Name      | Type                   | Params
-----------------------------------------------------
0 | sed_model | HTSAT_Swin_Transformer | 28.9 M
-----------------------------------------------------
27.8 M    Trainable params
1.1 M     Non-trainable params
28.9 M    Total params
115.404   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.665}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.62}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.505}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.615}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.59}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.49}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.645}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5625}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.44}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.54}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.62}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6175}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.615}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5825}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.645}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.585}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.52}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.545}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.585}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.635}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5775}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5975}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.61}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5975}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.695}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5725}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.63}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.63}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6075}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5275}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.585}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.67}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.645}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.565}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.54}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5825}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6175}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.59}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6275}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6225}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5375}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5175}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.515}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5625}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5075}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.49}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.615}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.46}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4275}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.665}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6375}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5425}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5025}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.545}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6525}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.53}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5725}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.65}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.485}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.58}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5075}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.565}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.595}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.525}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5775}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.545}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.635}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4825}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6125}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6325}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5575}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.485}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5925}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4675}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.625}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5525}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.535}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.5275}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.47}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4825}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.4825}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6225}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6025}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.57}


Validation: 0it [00:00, ?it/s]

cpu {'acc': 0.6675}


`Trainer.fit` stopped: `max_epochs=100` reached.
