In [1]:
import os
import sys

PROJECT_ROOT = os.path.abspath(os.path.join(
                  os.path.dirname("test-mode2"), 
                  os.pardir)
)
sys.path.append(PROJECT_ROOT)

import torch
import torch.nn as nn
import pywt
import numpy as np
from copy import deepcopy
from src.utils import read_feature, pad_features
from src.features import extract_mfcc, extract_melspectrogram

In [8]:
features_path = "../features/propor2022/"
feature = "mel_spectrogram"

# loading training features
X_train = read_feature(path=features_path, fold="0", name="X_train.pth")
y_train = read_feature(path=features_path, fold="0", name="y_train.pth")
print(f"Train: {X_train.shape}, {y_train.shape}")

# loading validation features
X_valid = read_feature(path=features_path, fold="0", name="X_valid.pth")
y_valid = read_feature(path=features_path, fold="0", name="y_valid.pth")
print(f"Valid: {X_valid.shape}, {y_valid.shape}")

# loading testing features
X_test = read_feature(path=features_path, fold=None, name="X_test.pth")
y_test = read_feature(path=features_path, fold=None, name="y_test.pth")
print(f"Test: {X_test.shape}, {y_test.shape}")

Train: torch.Size([500, 1, 128000]), torch.Size([500, 3])
Valid: torch.Size([125, 1, 128000]), torch.Size([125, 3])
Test: torch.Size([308, 1, 128000]), torch.Size([308, 3])


In [10]:
datas = []

for i in range(X_train.shape[0]):
    audio = deepcopy(X_train[i, :, :].detach().squeeze().numpy())

    coeffs = pywt.wavedec(
        data=audio,
        wavelet="db8",
        mode="symmetric",
        level=4
    )

    coeffs = [torch.from_numpy(c).unsqueeze(0) for c in coeffs]
    
    datas.append(coeffs)

for i in range(len(datas)):
    data = datas[i]
    feats = []
    
    for d in data:
        if feature == "mfcc":
            feat = extract_mfcc(
                audio=d,
                sample_rate=8000,
                n_fft=512,
                hop_length=256,
                n_mfcc=64,
                f_min=0,
                f_max=None
            )
        elif feature == "mel_spectrogram":
            feat = extract_melspectrogram(
                audio=d,
                sample_rate=8000,
                n_fft=512,
                hop_length=256,
                n_mels=128
            )
            
        feats.append(feat)
    
    # padding the mel spectrograms to be the same size
    max_height = max([x.size(1) for x in feats])
    max_width = max([x.size(2) for x in feats])

    feats = pad_features(
        features=feats,
        max_height=max_height,
        max_width=max_width
    )
    feats = torch.concat(feats, dim=0)
    feats = feats.permute(0, 2, 1) # time and frequency axis permutation
    
    print(feats.shape)

torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([

torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([5, 251, 128])
torch.Size([

In [11]:
feats = []

for i in range(X_train.shape[0]):
    audio = deepcopy(X_train[i, :, :].detach())

    if feature == "mfcc":
        feat = extract_mfcc(
            audio=d,
            sample_rate=8000,
            n_fft=512,
            hop_length=256,
            n_mfcc=64,
            f_min=0,
            f_max=None
        )
    elif feature == "mel_spectrogram":
        feat = extract_melspectrogram(
            audio=d,
            sample_rate=8000,
            n_fft=512,
            hop_length=256,
            n_mels=128
        )
    feat = feat.permute(0, 2, 1)
    feats.append(feat)

for i in range(len(feats)):
    spectrogram = feats[i].squeeze(0)
    
    coeffs = pywt.wavedec2(
        data=spectrogram,
        level=4,
        wavelet="db8", 
        mode="symmetric"
    )
    
    print(coeffs[0].shape)
    
    for i in range(1, len(coeffs)):
        for j in range(len(coeffs[i])):
            print(coeffs[i][j].shape)

    arr, coeffs = pywt.coeffs_to_array(coeffs)
    arr = torch.from_numpy(arr).unsqueeze(0)
    
    print(arr.shape)
    print()



(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71)
torch.Size([1, 309, 187])

(29, 22)
(29, 22)
(29, 22)
(29, 22)
(44, 29)
(44, 29)
(44, 29)
(74, 43)
(74, 43)
(74, 43)
(133, 71)
(133, 71)
(133, 71