## 1. Import

In [None]:
# basic
import os
import gzip
import bz2

# python
import pickle
import pywt
import librosa
import numpy as np
from pathlib import Path
from scipy import signal
from scipy import optimize
from time import time
!pip install fcwt

# notebook
from IPython.display import Audio
import matplotlib.pyplot as plt 
import cv2

# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

## 2. Configurations

In [None]:
class Config:
    """Central settings for preprocessing, training and inference.

    The audio directory used as inference input is resolved per instance in
    ``__init__``; every other setting is a plain class attribute.
    """

    def __init__(self):
        """Choose the directory that stands in for the test soundscapes.

        The competition folder is used when it holds more than one entry.
        Otherwise (at most one entry, or the folder is missing altogether)
        a training-audio folder is used instead and ``NO_SOUNDSCAPE`` is set.
        """
        soundscape_dir = Path('/kaggle/input/birdclef-2024/test_soundscapes')
        try:
            n_entries = len(os.listdir(soundscape_dir))
        except FileNotFoundError:
            # Outside the competition environment the folder may not exist;
            # treat that like "no soundscapes" instead of failing when the
            # configuration object is created.
            n_entries = 0

        if n_entries <= 1:
            self.TEST_SOUNDSCAPE = Path('/kaggle/input/birdclef-2024/train_audio/asbfly')
            self.NO_SOUNDSCAPE = True
        else:
            self.TEST_SOUNDSCAPE = soundscape_dir
            self.NO_SOUNDSCAPE = False

    model_name = "efficientnet_b0.ra_in1k"  # name of the model to use
    img_size = 224                          # input size; e.g. 256 means inputs are resized to 256x256
    n_folds = 5                             # number of folds
    interpolation = cv2.INTER_AREA          # interpolation method (the cv2 default is cv2.INTER_LINEAR)
    max_epoch = 9                           # maximum number of epochs (one epoch = one pass over the training set)
    # Batch size. The original author used 32 for training and 1 for
    # inference, noting that OpenVINO cannot cope with a smaller leftover batch.
    batch_size = 1
    lr = 1.0e-03                            # learning rate (step size of the weight updates)
    weight_decay = 1.0e-02                  # weight decay (regularisation against over-fitting)
    es_patience = 5                         # early-stopping patience
    seed = 1086                             # random seed
    deterministic = True                    # request identical results for identical settings
    enable_amp = False                      # enable/disable automatic mixed precision
    device = "cpu"                          # compute device; "cuda" selects an NVIDIA GPU

    simple_training = True                  # train on only a few samples
    simple_inferring = False                # infer on only a few samples
    n_simple = 100                          # number of samples used by the "simple" modes
    test = False                            # True when running inference

    DURATION = 5                            # length of a cropped clip
    LENGTH = DURATION

    # Number of mel bands (passed as n_mels). The original assigned
    # 128, 256, 128 and finally 64 in sequence; only the last value took effect.
    MELSPEC_H = 64

    TOP_DB = 100                            # maximum decibel value to clip audio to
    MIN_RATING = 0.0                        # minimum rating

    SR = 32000                              # sample rate
    audio_len = DURATION*SR

    N_FFT = 2048                            # STFT parameters (1024 and 5096 were earlier candidates)
    HOP_LENGTH = 512
    window = 2048
    fmin = 20
    fmax = 16000

    ONLY_ONCE = False                       # True: process a single file only
    VIEW = False                            # True: plot processed data
    VIEW_RAW = False                        # True: plot raw data

    # Related to faster inference
    INPUT_SHAPE: list[int] = [1, 1, 224, 224]
    DUMMY_INPUT_TENSOR: torch.Tensor = torch.randn(*INPUT_SHAPE)
    DUMMY_INPUT_NUMPY_FP32: np.ndarray = DUMMY_INPUT_TENSOR.numpy()
    DUMMY_INPUT_NUMPY_FP16: np.ndarray = DUMMY_INPUT_NUMPY_FP32.astype(np.float16)
    OUTPUT_DIR_ONNX: Path = Path('./model/onnx')
    OUTPUT_DIR_OV: Path = Path('./model/ov')


CFG = Config()

### Wavelet Transform

In [None]:
import numpy as np
import torch
from torch import nn

from scipy import optimize
from scipy.special import factorial, gamma, hermitenorm
from timm.models.layers import conv2d_same



### from https://github.com/tomrunia/PyTorchWavelets/blob/master/wavelets_pytorch/wavelets.py
class Morlet(object):
    """Complex Morlet wavelet with nondimensional frequency ``w0``."""

    def __init__(self, w0=6):
        """Store ``w0``.

        ``C_d`` (the constant from Torrence & Compo 1998) is only defined
        for ``w0 == 6``; for any other value the attribute is not created.
        """
        self.w0 = w0
        if w0 == 6:
            self.C_d = 0.776

    def __call__(self, *args, **kwargs):
        # Calling the object is shorthand for evaluating it in the time domain.
        return self.time(*args, **kwargs)

    def time(self, t, s=1.0, complete=True):
        """Evaluate the wavelet, centred at zero, in the time domain.

        Parameters
        ----------
        t : float or array
            Time. With the default ``s`` this is the nondimensional time.
        s : float
            Scaling factor. Default is 1.
        complete : bool
            If True, subtract the correction term ``exp(-0.5 * w0**2)``
            from the complex exponential (the "complete" Morlet wavelet).

        Returns
        -------
        out : complex
            With ``x = t / s``::

                pi**-0.25 * (exp(1j*w0*x) - c) * exp(-0.5*x**2)

            where ``c = exp(-0.5*w0**2)`` if ``complete`` else ``0``.
            No normalisation with respect to ``s`` is applied here.
        """
        omega0 = self.w0
        eta = t / s

        psi = np.exp(1j * omega0 * eta)
        if complete:
            # In-place on purpose: keeps the dtype of the oscillation term.
            psi -= np.exp(-0.5 * (omega0 ** 2))
        psi *= np.exp(-0.5 * (eta ** 2)) * np.pi ** (-0.25)
        return psi

    # Fourier wavelengths
    def fourier_period(self, s):
        """Return the equivalent Fourier period for scale ``s``."""
        denominator = self.w0 + (2 + self.w0 ** 2) ** 0.5
        return 4 * np.pi * s / denominator

    def scale_from_period(self, period):
        """Return the scale whose equivalent Fourier period is ``period``.

        This is ``fourier_period`` solved for the scale.
        """
        root_term = np.sqrt(self.w0 * self.w0 + 2)
        return (period * (root_term + self.w0)) / (4.0 * np.pi)

    # Frequency representation
    def frequency(self, w, s=1.0):
        """Evaluate the wavelet in the frequency domain.

        Parameters
        ----------
        w : float or array
            Angular frequency. With the default ``s`` this is the
            nondimensional angular frequency.
        s : float
            Scaling factor. Default is 1.

        Returns
        -------
        out
            ``pi**-0.25 * H(w) * exp(-(w*s - w0)**2 / 2)`` where ``H`` is 1
            for ``w > 0`` and 0 for ``w <= 0``.
        """
        scaled = w * s
        # Step function built on a copy of ``w`` so the caller's data is
        # untouched; the two masks are disjoint, so their order is irrelevant.
        step = np.array(w)
        step[w > 0] = 1
        step[w <= 0] = 0
        gaussian = np.exp((-((scaled - self.w0) ** 2)) / 2)
        return np.pi ** -0.25 * step * gaussian

    def coi(self, s):
        """Return the e-folding time ``sqrt(2) * s`` used for the cone of influence."""
        return 2 ** 0.5 * s


class Paul(object):
    """Complex Paul wavelet of order ``m``."""

    def __init__(self, m=4):
        """Store the order ``m`` of the wavelet."""
        self.m = m

    def __call__(self, *args, **kwargs):
        # Calling the object is shorthand for evaluating it in the time domain.
        return self.time(*args, **kwargs)

    def time(self, t, s=1.0):
        """Evaluate the wavelet, centred at zero, in the time domain.

        Parameters
        ----------
        t : float or array
            Time. With the default ``s`` this is the nondimensional time.
        s : float
            Scaling factor. Default is 1.

        Returns
        -------
        out : complex
            With ``x = t / s``::

                (2**m * i**m * m!) / sqrt(pi * (2*m)!) * (1 - i*x) ** -(m + 1)
        """
        order = self.m
        eta = t / s

        numerator = 2 ** order * 1j ** order * factorial(order)
        normalisation = numerator / (np.pi * factorial(2 * order)) ** 0.5
        shape = (1 - 1j * eta) ** -(order + 1)

        return normalisation * shape

    # Fourier wavelengths
    def fourier_period(self, s):
        """Return the equivalent Fourier period for scale ``s``."""
        return 4 * np.pi * s / (2 * self.m + 1)

    def scale_from_period(self, period):
        """Not available for this wavelet; always raises NotImplementedError."""
        raise NotImplementedError()

    # Frequency representation
    def frequency(self, w, s=1.0):
        """Evaluate the wavelet in the frequency domain.

        Parameters
        ----------
        w : float or array
            Angular frequency. With the default ``s`` this is the
            nondimensional angular frequency.
        s : float
            Scaling factor. Default is 1.

        Returns
        -------
        out
            With ``x = w * s``::

                2**m / sqrt(m * (2*m - 1)!) * H(x) * x**m * exp(-x)

            where ``H(x) = 0.5 * (sign(x) + 1)``, i.e. 0 for negative ``x``,
            0.5 at zero and 1 for positive ``x``.
        """
        order = self.m
        scaled = w * s

        step = 0.5 * (np.sign(scaled) + 1)
        prefactor = 2 ** order / (order * factorial(2 * order - 1)) ** 0.5
        shape = step * (scaled) ** order * np.exp(-scaled)

        return prefactor * shape

    def coi(self, s):
        """Return the e-folding time ``s / sqrt(2)`` used for the cone of influence."""
        return s / 2 ** 0.5


class DOG(object):
    """Derivative-of-Gaussian wavelet of order ``m``."""

    def __init__(self, m=2):
        """Store the order ``m``.

        ``C_d`` (the constant from Torrence & Compo 1998) is only defined
        for ``m == 2`` and ``m == 6``; for any other order the attribute is
        not created.
        """
        known_c_d = {2: 3.541, 6: 1.966}
        if m in known_c_d:
            self.C_d = known_c_d[m]
        self.m = m

    def __call__(self, *args, **kwargs):
        # Calling the object is shorthand for evaluating it in the time domain.
        return self.time(*args, **kwargs)

    def time(self, t, s=1.0):
        """Evaluate the wavelet in the time domain.

        With ``x = t / s`` the value computed is::

            A * He_m(x) * exp(-x**2 / 2) * exp(-1j * x)

        where ``A = (-1)**(m + 1) / sqrt(gamma(m + 1/2))`` and ``He_m`` is
        the probabilists' Hermite polynomial from
        ``scipy.special.hermitenorm``. No normalisation with respect to
        ``s`` is applied here.

        NOTE(review): the trailing ``exp(-1j * x)`` factor makes the result
        complex; a plain derivative of a Gaussian would be real-valued.
        Confirm that this factor is intended.

        Parameters
        ----------
        t : float or array
            Time. With the default ``s`` this is the nondimensional time.
        s : scalar
            Width parameter of the wavelet.

        Returns
        -------
        out : complex
            Value of the wavelet at the given time.
        """
        order = self.m
        eta = t / s

        # The Hermite polynomial supplies the polynomial part of the
        # m-th derivative of the Gaussian.
        hermite = hermitenorm(order)

        amplitude = (-1) ** (order + 1) / gamma(order + 0.5) ** 0.5
        shape = hermite(eta) * np.exp(-(eta ** 2) / 2) * np.exp(-1j * eta)

        return amplitude * shape

    def fourier_period(self, s):
        """Return the equivalent Fourier period for scale ``s``."""
        return 2 * np.pi * s / (self.m + 0.5) ** 0.5

    def scale_from_period(self, period):
        """Not available for this wavelet; always raises NotImplementedError."""
        raise NotImplementedError()

    def frequency(self, w, s=1.0):
        """Evaluate the wavelet in the frequency domain.

        Parameters
        ----------
        w : float or array
            Angular frequency. With the default ``s`` this is the
            nondimensional angular frequency.
        s : float
            Scaling factor. Default is 1.

        Returns
        -------
        out : complex
            With ``x = s * w``::

                -(1j**m) / sqrt(gamma(m + 1/2)) * x**m * exp(-x**2 / 2)
        """
        order = self.m
        scaled = s * w
        amplitude = -(1j ** order) / gamma(order + 0.5) ** 0.5
        shape = scaled ** order * np.exp(-(scaled ** 2) / 2)
        return amplitude * shape

    def coi(self, s):
        """Return the e-folding time ``sqrt(2) * s`` used for the cone of influence."""
        return 2 ** 0.5 * s


class Ricker(DOG):
    """Ricker (also called Marr or Mexican-hat) wavelet.

    This is the derivative-of-Gaussian wavelet of order 2.
    """

    def __init__(self):
        """Initialise the parent with ``m=2`` and set ``C_d`` explicitly."""
        super().__init__(m=2)
        # value of C_d from TC98
        self.C_d = 3.541




### from https://www.kaggle.com/code/anjum48/continuous-wavelet-transform-cwt-in-pytorch#PyTorch-implementation
class CWT2(nn.Module):
    """Continuous wavelet transform implemented as a 2-D convolution.

    The filter bank is built lazily on the first ``forward`` call from the
    length of that first input and is reused for every later call.
    """

    def __init__(
        self,
        dj=0.0625,
        dt=1 / 2048,
        wavelet=Morlet(),
        fmin: int = 20,
        fmax: int = 500,
        output_format="Magnitude",
        trainable=False,
        hop_length: int = 1,
    ):
        """Store the transform settings and compute the smallest scale.

        Args:
            dj: Exponent step between consecutive scales
                (scales are ``s0 * 2 ** (dj * k)``).
            dt: Sampling interval used to convert scales to sample counts.
            wavelet: Object providing ``fourier_period(s)`` and a call
                ``wavelet(t, s)``. The default instance is created once,
                when the class is defined, and shared by all instances
                that do not pass their own.
            fmin: Scales whose equivalent frequency is below this are
                dropped; a falsy value disables the check.
            fmax: Scales whose equivalent frequency is above this are
                dropped; a falsy value disables the check.
            output_format: ``"Magnitude"`` returns the modulus; any other
                value returns real and imaginary parts stacked on a new
                last axis (complex kernels only).
            trainable: Stored but not used by the visible code.
            hop_length: Stride along the time axis of the convolution.
        """
        super().__init__()
        self.wavelet = wavelet

        self.dt = dt
        self.dj = dj
        self.fmin = fmin
        self.fmax = fmax
        self.output_format = output_format
        self.trainable = trainable  # TODO make kernel a trainable parameter
        self.stride = (1, hop_length)
        # self.padding = 0  # "same"

        self._scale_minimum = self.compute_minimum_scale()

        # Everything below is filled in by the first forward() call.
        self.signal_length = None
        # NOTE(review): forward() assigns ``self.channels`` (no underscore);
        # this attribute is never written again in the visible code.
        self._channels = None

        self._scales = None
        self._kernel = None
        self._kernel_real = None
        self._kernel_imag = None

    def compute_optimal_scales(self):
        """
        Determines the optimal scale distribution (see Torrence & Compo, Eq. 9-10).

        Scales grow geometrically from the minimum scale. They are then
        trimmed so that the equivalent frequencies ``1 / fourier_period(s)``
        lie within ``[fmin, fmax]``.

        :return: np.ndarray, collection of scales
        :raises ValueError: if ``signal_length`` has not been set yet
        """
        if self.signal_length is None:
            raise ValueError(
                "Please specify signal_length before computing optimal scales."
            )
        J = int(
            (1 / self.dj) * np.log2(self.signal_length * self.dt / self._scale_minimum)
        )
        scales = self._scale_minimum * 2 ** (self.dj * np.arange(0, J + 1))

        # Remove high and low frequencies
        frequencies = np.array([1 / self.wavelet.fourier_period(s) for s in scales])
        if self.fmin:
            # Keeps the leading scales: as many as there are frequencies >= fmin.
            frequencies = frequencies[frequencies >= self.fmin]
            scales = scales[0 : len(frequencies)]
        if self.fmax:
            # Keeps the trailing scales: as many as there are frequencies <= fmax.
            frequencies = frequencies[frequencies <= self.fmax]
            scales = scales[len(scales) - len(frequencies) : len(scales)]

        return scales

    def compute_minimum_scale(self):
        """
        Choose s0 so that the equivalent Fourier period is 2 * dt.
        See Torrence & Compo Sections 3f and 3h.

        The scale is found numerically with ``scipy.optimize.fsolve``,
        starting from an initial guess of 1.

        :return: float, minimum scale level
        """
        dt = self.dt

        def func_to_solve(s):
            return self.wavelet.fourier_period(s) - 2 * dt

        return optimize.fsolve(func_to_solve, 1)[0]

    def _build_filters(self):
        """Sample the wavelet at every scale and store the filters in ``self._filters``.

        Each filter is the complex conjugate of the time-reversed, normalised
        wavelet. All filters are finally zero-padded to a common length.
        """
        self._filters = []
        for scale_idx, scale in enumerate(self._scales):
            # Number of points needed to capture wavelet
            M = 10 * scale / self.dt
            # Times to use, centred at zero
            t = torch.arange((-M + 1) / 2.0, (M + 1) / 2.0) * self.dt
            if len(t) % 2 == 0:
                t = t[0:-1]  # requires odd filter size
            # Sample wavelet and normalise
            norm = (self.dt / scale) ** 0.5
            filter_ = norm * self.wavelet(t, scale)
            self._filters.append(torch.conj(torch.flip(filter_, [-1])))

        self._pad_filters()

    def _pad_filters(self):
        """Zero-pad every filter symmetrically to the length of the last filter.

        The last filter is taken as the reference length; presumably it is
        the longest because it belongs to the largest scale - TODO confirm.
        """
        filter_len = self._filters[-1].shape[0]
        padded_filters = []

        for f in self._filters:
            pad = (filter_len - f.shape[0]) // 2
            padded_filters.append(nn.functional.pad(f, (pad, pad)))

        self._filters = padded_filters

    def _build_wavelet_bank(self):
        """This function builds a 2D wavelet filter using wavelets at different scales

        Returns:
            tensor: Tensor of shape (num_widths, 1, 1, filter_len)
        """
        self._build_filters()
        wavelet_bank = torch.stack(self._filters)
        wavelet_bank = wavelet_bank.view(
            wavelet_bank.shape[0], 1, 1, wavelet_bank.shape[1]
        )
        # See comment by tez6c32
        # https://www.kaggle.com/anjum48/continuous-wavelet-transform-cwt-in-pytorch/comments#1499878
        # wavelet_bank = torch.cat([wavelet_bank] * self.channels, 2)
        return wavelet_bank

    def forward(self, x):
        """Compute CWT arrays from a batch of multi-channel inputs

        On the first call the scales and the wavelet bank are derived from
        the length of ``x``. Later calls reuse them, even if the input length
        differs.

        Args:
            x (torch.tensor): Tensor of shape (batch_size, channels, time)

        Returns:
            torch.tensor: Tensor of shape (batch_size, channels, widths, time).
            For a complex kernel with ``output_format`` other than
            ``"Magnitude"`` the real and imaginary parts are stacked on an
            additional last axis.
        """
        if self.signal_length is None:
            self.signal_length = x.shape[-1]
            self.channels = x.shape[-2]
            self._scales = self.compute_optimal_scales()
            self._kernel = self._build_wavelet_bank()

            if self._kernel.is_complex():
                self._kernel_real = self._kernel.real
                self._kernel_imag = self._kernel.imag

        # Insert a singleton axis so conv2d sees one input plane per sample.
        x = x.unsqueeze(1)

        if self._kernel.is_complex():
            # Move/cast the cached kernels only when the input's device or
            # dtype differs from theirs.
            if (
                x.dtype != self._kernel_real.dtype
                or x.device != self._kernel_real.device
            ):
                self._kernel_real = self._kernel_real.to(device=x.device, dtype=x.dtype)
                self._kernel_imag = self._kernel_imag.to(device=x.device, dtype=x.dtype)

            # Strides > 1 not yet supported for "same" padding
            # output_real = nn.functional.conv2d(
            #     x, self._kernel_real, padding=self.padding, stride=self.stride
            # )
            # output_imag = nn.functional.conv2d(
            #     x, self._kernel_imag, padding=self.padding, stride=self.stride
            # )
            output_real = conv2d_same(x, self._kernel_real, stride=self.stride)
            output_imag = conv2d_same(x, self._kernel_imag, stride=self.stride)
            # Swap the scale axis (conv output channels) with the input-channel axis.
            output_real = torch.transpose(output_real, 1, 2)
            output_imag = torch.transpose(output_imag, 1, 2)

            if self.output_format == "Magnitude":
                return torch.sqrt(output_real ** 2 + output_imag ** 2)
            else:
                return torch.stack([output_real, output_imag], -1)

        else:
            if x.device != self._kernel.device or x.dtype != self._kernel.dtype:
                self._kernel = self._kernel.to(device=x.device, dtype=x.dtype)

            # output = nn.functional.conv2d(
            #     x, self._kernel, padding=self.padding, stride=self.stride
            # )
            output = conv2d_same(x, self._kernel, stride=self.stride)
            return torch.transpose(output, 1, 2)


## 3. Global Variables

In [None]:
# Audio input directories. The ADDED_TRAIN* paths are only referenced by
# commented-out code in the visible part of this notebook.
KAGGLE_TRAIN = '/kaggle/input/birdclef-2024/train_audio'
ADDED_TRAIN = '/kaggle/input/birdclef2024-additional-mp3/additional_audio'
ADDED_TRAIN_1 = '/kaggle/input/birdclef2024-additional-wav-1/additional_audio-1'
ADDED_TRAIN_2 = '/kaggle/input/birdclef2024-additional-wav-2/additional_audio-2'
# Unlike the plain-string paths above, this one is a pathlib.Path chosen by Config.__init__.
TEST_SOUNDSCAPE = CFG.TEST_SOUNDSCAPE
# Creates the training-audio directory if it is missing (no-op when it exists).
# NOTE(review): this path is under /kaggle/input, presumably read-only on Kaggle - confirm this call is needed.
os.makedirs(KAGGLE_TRAIN, exist_ok=True)

# Output directories for the arrays generated from audio
SAVE_TRAIN = '/kaggle/working/train_image'
SAVE_TEST = '/kaggle/working/test_image'

# Image (spectrogram) input directories
TRAIN_IMAGE = Path('/kaggle/input/bird2024-spec-v6/train_image/spec')
TEST_IMAGE = Path('/kaggle/working/test_image/spec')

# Directory of the trained model
TRAINED_MODEL = Path('/kaggle/input/birdcref-2024-introduction-withtraining-train-v2')


## 4. Preprocessing

In [None]:
class preprocessing():
    """Turn audio files into feature arrays and write them to disk.

    Every ``func_*`` method converts the waveform currently held in
    ``self.y`` (set by ``load_wave``) into one representation. ``execute``
    runs the selected ``func_*`` methods over all audio files and saves the
    results, normalised to ``uint8``.

    ``AUDIO_DIRECTORY`` and ``SAVE_DIRECTORY`` may be given as ``str`` or
    ``pathlib.Path``; they are stored unchanged and converted where needed.
    """

    def __init__(self, AUDIO_DIRECTORY, SAVE_DIRECTORY, view=CFG.VIEW, view_raw=CFG.VIEW_RAW, test=CFG.test):
        """Store the settings and create the output directory tree.

        Args:
            AUDIO_DIRECTORY: In train mode a directory with one
                sub-directory per species; in test mode a directory that
                directly contains ``*.ogg`` files.
            SAVE_DIRECTORY: Root directory for the generated arrays.
            view: If True, the ``func_*`` methods print shapes and plot.
            view_raw: Stored; not used by the visible code.
            test: Selects test mode (True) or train mode (False).
        """
        # config
        self.AUDIO_DIRECTORY = AUDIO_DIRECTORY
        self.SAVE_DIRECTORY = SAVE_DIRECTORY
        self.view = view
        self.view_raw = view_raw
        self.test = test

        # make directory: one sub-directory per func_* method, named after
        # the part of the method name following the last underscore
        func_names = [method for method in dir(self) if callable(getattr(self, method)) and method.startswith("func")]
        print(func_names)
        save_root = Path(self.SAVE_DIRECTORY)
        save_root.mkdir(parents=True, exist_ok=True)
        for func_name in func_names:
            (save_root / func_name.split('_')[-1]).mkdir(exist_ok=True)

    def load_wave(self, audio_filepath, offset=0, duration=5, desired_length=960000):
        """Load audio into ``self.y`` and make it at least ``desired_length`` samples long.

        Up to 30 seconds are read starting at ``offset`` and resampled to
        ``CFG.SR``. A shorter signal is repeated cyclically until it has
        exactly ``desired_length`` samples. If nothing could be read, a
        silent (all-zero) signal of that length is used instead.

        Args:
            audio_filepath: Path of the audio file.
            offset: Start position in seconds.
            duration: Accepted for compatibility but ignored; the read
                window is fixed at 30 seconds.
                NOTE(review): callers pass 5 here - confirm which is intended.
            desired_length: Minimum number of samples of the result.
        """
        # load audio
        self.y, _ = librosa.load(audio_filepath, sr=CFG.SR, offset=offset, duration=30)
        n_samples = self.y.shape[0]
        if n_samples < desired_length:
            if n_samples == 0:
                # Repeating an empty signal never reaches the target
                # length (the original loop did not terminate here).
                self.y = np.zeros(desired_length, dtype=self.y.dtype)
            else:
                # np.resize fills the new array with repeated copies of
                # the signal, i.e. the same cyclic extension as before.
                self.y = np.resize(self.y, desired_length)
        self.sr = CFG.SR

    def normalize(self, data: np.ndarray):
        """Rescale ``data`` linearly to 0..255 and return it as ``uint8``.

        ``uint8`` input is returned unchanged. Constant (or empty) input
        yields zeros instead of dividing by zero.
        """
        if data.dtype == np.uint8:
            return data
        data = data.astype(np.single)
        if data.size == 0:
            return data.astype(np.uint8)
        # Shift so that the minimum becomes 0
        data = data - data.min()
        # Scale so that the maximum becomes 255
        peak = data.max()
        if peak > 0:
            data = data / peak * 255
        return data.astype(np.uint8)

    # apply processing and save the result
    def apply_func(self, function):
        """Run ``function`` on every audio file and save each result.

        Args:
            function: A bound ``func_*`` method without arguments that
                converts ``self.y`` into an array.
        """
        if self.test:
            self._apply_to_test(function)
        else:
            self._apply_to_train(function)

    def _apply_to_train(self, function):
        """Process ``<AUDIO_DIRECTORY>/<species>/<file>`` and save gzip-pickled arrays."""
        audio_root = Path(self.AUDIO_DIRECTORY)
        # All train outputs go to the same folder, whatever the function.
        save_dir = Path(self.SAVE_DIRECTORY) / 'second_30'
        save_dir.mkdir(parents=True, exist_ok=True)
        n_load = int(CFG.DURATION/CFG.LENGTH)

        for species in os.listdir(audio_root):
            species_path = audio_root / species
            if not species_path.is_dir():
                # Stray files next to the species folders are not audio sets.
                continue

            for audio_file in os.listdir(species_path):
                audio_filepath = str(species_path / audio_file)

                for i in range(n_load):
                    # load audio
                    self.load_wave(audio_filepath, offset=int(i*CFG.LENGTH), duration=CFG.LENGTH)
                    # apply function
                    output = self.normalize(function())

                    # The file name does not contain ``i``; with n_load > 1
                    # later segments overwrite earlier ones.
                    save_path = save_dir / f"{audio_file.split('.')[0]}.npy"
                    # save as pickle with gzip
                    self.save_as_pickle_gzip(output, save_path)
                    del output

                if CFG.ONLY_ONCE:
                    break
            if CFG.ONLY_ONCE:
                break

    def _apply_to_test(self, function):
        """Process every ``*.ogg`` file in 5-second offset steps and save ``.npy`` arrays."""
        audio_root = Path(self.AUDIO_DIRECTORY)
        save_dir = Path(self.SAVE_DIRECTORY) / function.__name__.split('_')[-1]
        save_dir.mkdir(parents=True, exist_ok=True)
        # Part of the file stem that is removed to obtain the id.
        id_prefix = 'XC' if CFG.NO_SOUNDSCAPE else 'soundscape_'

        audio_length = int(4*60) # second
        audio_offset_unit_max = int(audio_length / 5)
        for audio_file_path in audio_root.glob('*.ogg'):
            for audio_offset_unit in range(audio_offset_unit_max):
                audio_offset = audio_offset_unit * 5
                # load audio
                # NOTE(review): load_wave reads 30 s from each offset, so
                # consecutive outputs overlap - confirm this is intended.
                self.load_wave(str(audio_file_path), audio_offset)
                # apply function
                output = self.normalize(function())

                # save as [id]_[end_time].npy
                file_id = audio_file_path.stem.replace(id_prefix, '')
                np.save(save_dir / f"{file_id}_{audio_offset+5}.npy", output)
                del output

                if CFG.ONLY_ONCE:
                    break
            if CFG.ONLY_ONCE:
                break

    def save_as_pickle_gzip(self, data, filepath):
        """Pickle ``data`` into a gzip file at ``filepath`` + ``'.gz'``."""
        with gzip.open(str(filepath) + '.gz', 'wb') as f:
            pickle.dump(data, f)

    # as is
    def func_waveform(self):
        """Return the raw waveform ``self.y`` (optionally showing it)."""
        if self.view:
            print('waveform shape: ', self.y.shape)
            print('waveform type: ', type(self.y))
            print('waveform value type: ', type(self.y[0]))
            print('waveform shape type: ', type(self.y.shape))
            print('waveform shape value type: ', type(self.y.shape[0]))
            display(Audio(self.y, rate=self.sr))
            plt.figure(figsize=(10, 4))
            librosa.display.waveshow(self.y, sr=self.sr)
            plt.title('Waveform')
            plt.xlabel('Time (s)')
            plt.ylabel('Amplitude')
            plt.show()
        return self.y

    # spectrogram
    def func_spec(self):
        """Return the dB-scaled STFT magnitude, min-max scaled to 0..1."""
        spec = librosa.amplitude_to_db(np.abs(librosa.stft(self.y)), ref=np.max)
        min_ = spec.min()
        max_ = spec.max()
        # A constant spectrogram is left as is to avoid dividing by zero.
        if max_ != min_:
            spec = (spec - min_)/(max_ - min_)

        if self.view:
            print('spec shape: ', spec.shape)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(spec, sr=self.sr, x_axis='time', y_axis='log')
            plt.colorbar(format='%+2.0f dB')
            plt.title('Spectrogram')
            plt.show()

        return spec

    # melspectrogram
    def func_melspec(self, normalization=False):
        """Return the dB-scaled mel-spectrogram of ``self.y``.

        Args:
            normalization: Accepted but not used.
        """
        melspec = librosa.feature.melspectrogram(
            y=self.y,
            sr=CFG.SR,                  # sample rate
            n_fft=CFG.N_FFT,            # number of samples in window
            hop_length=CFG.HOP_LENGTH,  # step size of window
            n_mels=CFG.MELSPEC_H,       # number of mel bands between fmin and fmax
            fmin=CFG.fmin,              # minimum frequency
            fmax=CFG.fmax,              # maximum frequency
            power=2.0,                  # exponent of the magnitude (2.0 = power)
        )
        melspec = librosa.power_to_db(melspec, ref=np.max)

        if self.view:
            print('melspec shape: ', melspec.shape)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(melspec, sr=self.sr, x_axis='time', y_axis='mel')
            plt.colorbar(format='%+2.0f dB')
            plt.title('Mel Spectrogram')
            plt.show()

        return melspec

    # scalogram
    def func_scalogram(self):
        """Return the magnitude of the continuous wavelet transform of ``self.y``.

        NOTE(review): the waveform is quantised to uint8 by ``normalize``
        before the transform, and ``dt=1/1000`` does not match ``CFG.SR`` -
        confirm both are intended.
        """
        sig_pt = torch.tensor(self.normalize(self.y), dtype=torch.float32)
        sig_pt = torch.stack([sig_pt] * 1)  # add a channel axis of size 1
        sig_pt = torch.stack([sig_pt] * 1)  # add a batch axis of size 1

        pycwt = CWT2(dt=1/1000)
        out = pycwt(sig_pt)

        if self.view:
            print('scarogram shape: ', out.shape)
            plt.figure(figsize=(10, 4))
            plt.imshow(abs(out[0,0]), aspect='auto', extent=[0, len(self.y) / self.sr, 1, 100], cmap='jet', origin='lower')
            plt.colorbar()
            plt.title('Scalogram')
            plt.xlabel('Time (s)')
            plt.ylabel('Scale')
            plt.show()

        return abs(out[0, 0].to('cpu').detach().numpy().copy())

    # chromagram
    def func_chromagram(self):
        """Return the constant-Q chromagram of ``self.y``."""
        C = librosa.feature.chroma_cqt(y=self.y, sr=self.sr)

        if self.view:
            print('chromagram shape: ', C.shape)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(C, sr=self.sr, x_axis='time', y_axis='chroma', cmap='coolwarm')
            plt.colorbar()
            plt.title('Chromagram')
            plt.show()

        return C

    # Mel-Frequency Cepstral Coefficients(mfcc)
    def func_mfcc(self):
        """Return the MFCCs of ``self.y``."""
        mfcc = librosa.feature.mfcc(y=self.y, sr=self.sr)

        if self.view:
            print('mfcc shape: ', mfcc.shape)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(mfcc, sr=self.sr, x_axis='time')
            plt.ylabel('MFCC coeffs')
            plt.colorbar()
            plt.title('MFCC')
            plt.show()

        return mfcc

    # spectral contrast
    def func_spectralcontrast(self):
        """Return the spectral contrast of ``self.y``."""
        contrast = librosa.feature.spectral_contrast(y=self.y, sr=self.sr)

        if self.view:
            print('contrast shape: ', contrast.shape)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(contrast, x_axis='time')
            plt.colorbar()
            plt.ylabel('Frequency bands')
            plt.title('Spectral Contrast')
            plt.show()

        return contrast

    # mel-spectrogram on top, scalogram underneath
    def func_melspecscalo(self):
        """Return the mel-spectrogram stacked on top of the scalogram.

        Both parts are normalised to ``uint8``. The scalogram is resized to
        128 rows and to as many columns as the mel-spectrogram has, so the
        two can be concatenated whatever the audio length.
        """
        melspec = self.func_melspec()
        melspec = self.normalize(melspec)

        scalo = self.func_scalogram()
        scalo = self.normalize(scalo)
        # cv2 expects (width, height). The width follows the mel-spectrogram;
        # the former fixed value int(62.6*CFG.LENGTH) only matched 5 s clips.
        size_align = (melspec.shape[1], 128)
        scalo = cv2.resize(scalo, size_align, interpolation=cv2.INTER_AREA)

        melspec_scalo = np.concatenate([melspec, scalo], axis=0)

        return melspec_scalo

    def execute(self):
        """Apply every selected ``func_*`` method to all audio files."""
        # select the processing
        func_list = [
            self.func_waveform,
#             self.func_spec,
#             self.func_melspec,
#             self.func_scalogram,
#             self.func_chromagram,
#             self.func_mfcc,
#             self.func_spectralcontrast,
#             self.func_melspecscalo,
        ]
        for func in func_list:
            self.apply_func(func)

# Create the preprocessor for the competition training audio.
# Constructing it already creates the output directories under SAVE_TRAIN.
# The commented-out variants cover the additional datasets and the test set.
preprocessing_kaggle = preprocessing(KAGGLE_TRAIN, SAVE_TRAIN)
# preprocessing_added_train = preprocessing(ADDED_TRAIN, SAVE_TRAIN)
# preprocessing_added_train_1 = preprocessing(ADDED_TRAIN_1, SAVE_TRAIN)
# preprocessing_added_train_2 = preprocessing(ADDED_TRAIN_2, SAVE_TRAIN)
# preprocessing_test = preprocessing(TEST_SOUNDSCAPE, SAVE_TEST)

t1 = time()
# Run the preprocessing; t1/t2 bracket the run to measure its wall-clock time.
preprocessing_kaggle.execute()
# preprocessing_added_train.execute()
# preprocessing_added_train_1.execute()
# preprocessing_added_train_2.execute()
# preprocessing_test.execute()
t2 = time()
# Elapsed time is reported in minutes.
print('Preprocessing time: ', f"{(t2-t1)/60}m")

END