In [1]:
import torch
import torchaudio
from IPython.display import Audio, display

from s3prl.downstream.augment_utils.tempo_perturbation import TempoPerturbation
from s3prl.downstream.mdd.dataset import L2ArcticDataset

2024-04-04 13:48:14.164293: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-04 13:48:14.196581: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
dataset = L2ArcticDataset(
    'train',
    '/home/xt0r3-user/cambridge/partii/dissertation/s3prl/data/l2arctic_release_v5.0',
    '/home/xt0r3-user/cambridge/partii/dissertation/s3prl/s3prl/s3prl/downstream/mdd/data/'
)

Skipping malformatted TextGrid file: /home/xt0r3-user/cambridge/partii/dissertation/s3prl/data/l2arctic_release_v5.0/YDCK/annotation/arctic_a0209.TextGrid


In [3]:
def play_audio(waveform, sample_rate):
  waveform = waveform.numpy()
  display(Audio(waveform, rate=sample_rate))

In [4]:
audio = torchaudio.functional.resample(dataset[0][0], orig_freq=44100, new_freq=16000)


play_audio(audio, 16000)

In [5]:
import numpy as np
from scipy.interpolate import interp1d

def _validate_audio(audio):
    """validate the input audio and modify the order of channels.

    Parameters
    ----------

    audio : numpy.ndarray [shape=(channel, num_samples) or (num_samples)\
                           or (num_samples, channel)]
            the input audio sequence to validate.

    Returns
    -------

    audio : numpy.ndarray [shape=(channel, num_samples)]
            the validataed output audio sequence.
    """
    if audio.ndim == 1:
        audio = np.expand_dims(audio, 0)
    elif audio.ndim > 2:
        raise Exception("Please use the valid audio source. "
                        + "Number of dimension of input should be less than 3.")
    elif audio.shape[0] > audio.shape[1]:
        print('it seems that the 2nd axis of the input audio source '
             + 'is a channel. it is recommended that fix channel '
             + 'to the 1st axis.', stacklevel=3)
        audio = audio.T

    return audio


def _validate_scale_factor(audio, s):
    """Validate the scale factor s and
    convert the fixed scale factor to anchor points.

    Parameters
    ----------

    audio : numpy.ndarray [shape=(num_channels, num_samples) \
                           or (num_samples) or (num_samples, num_channels)]
            the input audio sequence.
    s : number > 0 [scalar] or numpy.ndarray [shape=(2, num_points) \
        or (num_points, 2)]
        the time stretching factor. Either a constant value (alpha)
        or an (2 x n) (or (n x 2)) array of anchor points
        which contains the sample points of the input signal in the first row
        and the sample points of the output signal in the second row.

    Returns
    -------

    anc_points : numpy.ndarray [shape=(2, num_points)]
                 anchor points which contains the sample points
                 of the input signal in the first row
                 and the sample points of the output signal in the second row.
    """
    if np.isscalar(s):
        anc_points = np.array([[0, np.shape(audio)[1] - 1],
                               [0, np.ceil(s * np.shape(audio)[1]) - 1]])
    elif s.ndim == 2:
        if s.shape[0] == 2:
            anc_points = s
        elif s.shape[1] == 2:
            print('it seems that the anchor points '
                 + 'has shape (num_points, 2). '
                 + 'it is recommended to '
                 + 'have shape (2, num_points).', stacklevel=3)
            anc_points = s.T
    else:
        raise Exception('Please use the valid anchor points. '
                        + '(scalar or pair of input/output sample points)')

    return anc_points

def wsola(x, s, win_type='hann',
          win_size=1024, syn_hop_size=512, tolerance=512):
    """Modify length of the audio sequence using WSOLA algorithm.

    Parameters
    ----------

    x : numpy.ndarray [shape=(channel, num_samples) or (num_samples)]
        the input audio sequence to modify.
    s : number > 0 [scalar] or numpy.ndarray [shape=(2, num_points)]
        the time stretching factor. Either a constant value (alpha)
        or an 2 x n array of anchor points which contains the sample points
        of the input signal in the first row
        and the sample points of the output signal in the second row.
    win_type : str
               type of the window function. hann and sin are available.
    win_size : int > 0 [scalar]
               size of the window function.
    syn_hop_size : int > 0 [scalar]
                   hop size of the synthesis window.
                   Usually half of the window size.
    tolerance : int >= 0 [scalar]
                number of samples the window positions
                in the input signal may be shifted
                to avoid phase discontinuities when overlap-adding them
                to form the output signal (given in samples).

    Returns
    -------

    y : numpy.ndarray [shape=(channel, num_samples) or (num_samples)]
        the modified output audio sequence.
    """
    # validate the input audio and scale factor.
    x = _validate_audio(x)
    anc_points = _validate_scale_factor(x, s)

    n_chan = x.shape[0]
    output_length = int(anc_points[-1, -1]) + 1

    display(f"{output_length}")



    win = torch.hann_window(win_size).numpy()

    sw_pos = np.arange(0, output_length + win_size // 2, syn_hop_size)
    ana_interpolated = interp1d(anc_points[1, :], anc_points[0, :],
                                fill_value='extrapolate')
    aw_pos = np.round(ana_interpolated(sw_pos)).astype(int)

    display(f"{sw_pos=}")
    display(f"{sw_pos.shape=}")
    display(f"{aw_pos=}")
    display(f"{aw_pos.shape=}")

    ana_hop = np.insert(aw_pos[1:] - aw_pos[0: -1], 0, 0)

    display(f"{ana_hop=}")

    y = np.zeros((n_chan, output_length))

    min_fac = np.min(syn_hop_size / ana_hop[1:])

    # padding the input audio sequence.
    left_pad = int(win_size // 2 + tolerance)
    right_pad = int(np.ceil(1 / min_fac) * win_size + tolerance)
    x_padded = np.pad(x, ((0, 0), (left_pad, right_pad)), 'constant')

    aw_pos = aw_pos + tolerance

    # Applying WSOLA to each channels
    for c, x_chan in enumerate(x_padded):
        y_chan = np.zeros(output_length + 2 * win_size)
        ow = np.zeros(output_length + 2 * win_size)

        delta = 0

        for i in range(len(aw_pos) - 1):
            x_adj = x_chan[aw_pos[i] + delta: aw_pos[i] + win_size + delta]
            y_chan[sw_pos[i]: sw_pos[i] + win_size] += x_adj * win
            ow[sw_pos[i]: sw_pos[i] + win_size] += win

            nat_prog = x_chan[aw_pos[i] + delta + syn_hop_size:
                              aw_pos[i] + delta + syn_hop_size + win_size]

            next_aw_range = np.arange(aw_pos[i+1] - tolerance,
                                      aw_pos[i+1] + win_size + tolerance)

            x_next = x_chan[next_aw_range]

            cross_corr = np.correlate(nat_prog, x_next)
            print(f"{cross_corr=}")
            print(f"{cross_corr.shape=}")

            max_index = np.argmax(cross_corr)

            delta = tolerance - max_index

        # Calculate last frame
        x_adj = x_chan[aw_pos[-1] + delta: aw_pos[-1] + win_size + delta]
        y_chan[sw_pos[-1]: sw_pos[-1] + win_size] += x_adj * win
        ow[sw_pos[-1]: sw_pos[-1] + win_size] += + win

        ow[ow < 1e-3] = 1

        y_chan = y_chan / ow
        y_chan = y_chan[win_size // 2:]
        y_chan = y_chan[: output_length]

        y[c, :] = y_chan

    return y.squeeze()

In [12]:
display(Audio(wsola(audio, 1.2), rate=16000))

'94464'

'sw_pos=array([    0,   512,  1024,  1536,  2048,  2560,  3072,  3584,  4096,\n        4608,  5120,  5632,  6144,  6656,  7168,  7680,  8192,  8704,\n        9216,  9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312,\n       13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920,\n       18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528,\n       23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136,\n       27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744,\n       32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352,\n       36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960,\n       41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568,\n       46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176,\n       50688, 51200, 51712, 52224, 52736, 53248, 53760, 54272, 54784,\n       55296, 55808, 56320, 56832, 57344, 57856, 58368, 58880, 59392,\n       59904, 60416, 60928, 61440, 61952, 62464, 62976, 63488, 64000,

'sw_pos.shape=(186,)'

'aw_pos=array([    0,   427,   853,  1280,  1707,  2133,  2560,  2987,  3413,\n        3840,  4267,  4693,  5120,  5547,  5973,  6400,  6827,  7253,\n        7680,  8107,  8533,  8960,  9387,  9813, 10240, 10667, 11093,\n       11520, 11947, 12373, 12800, 13227, 13653, 14080, 14507, 14933,\n       15360, 15787, 16213, 16640, 17067, 17493, 17920, 18347, 18773,\n       19200, 19627, 20053, 20480, 20907, 21333, 21760, 22187, 22613,\n       23040, 23467, 23893, 24320, 24747, 25173, 25600, 26027, 26453,\n       26880, 27307, 27733, 28160, 28587, 29013, 29440, 29867, 30293,\n       30720, 31147, 31573, 32000, 32427, 32853, 33280, 33707, 34133,\n       34560, 34987, 35413, 35840, 36267, 36693, 37120, 37547, 37973,\n       38400, 38827, 39253, 39680, 40107, 40533, 40960, 41387, 41813,\n       42240, 42667, 43093, 43520, 43947, 44373, 44800, 45227, 45653,\n       46080, 46507, 46933, 47360, 47787, 48213, 48640, 49067, 49493,\n       49920, 50347, 50773, 51200, 51627, 52053, 52480, 52907, 53333,

'aw_pos.shape=(186,)'

'ana_hop=array([  0, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427,\n       427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427,\n       426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426,\n       427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427,\n       427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427,\n       426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426,\n       427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427,\n       427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427,\n       426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426,\n       427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427,\n       427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427,\n       426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426,\n       427, 427, 426, 427, 427, 426, 427, 427, 426, 427, 427, 426, 427,\n       427, 426, 427, 427, 426, 427, 427, 

cross_corr=array([-0.00461966, -0.00510644, -0.00564242, ..., -0.00585969,
       -0.00585083, -0.00583607], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([ 0.01339211,  0.01396581,  0.01486445, ..., -0.00721942,
       -0.00731433, -0.00741667], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-0.07324108, -0.07366276, -0.07474163, ...,  0.02984602,
        0.03050624,  0.0307083 ], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-0.09272612, -0.08768628, -0.08226316, ..., -0.00137451,
       -0.00056305,  0.00038463], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-2.7806675e+00, -4.3720245e+00, -5.7539091e+00, ...,
       -1.0144569e-02, -8.2860142e-04,  7.1138432e-03], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([ -6.2453957 ,  -8.433781  , -10.526524  , ...,   0.12896569,
         0.11623617,   0.1037119 ], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-2.0560644 , -0.8228899 ,  0.40082404, ...,  0.15092978,
       

In [11]:
augmentation = TempoPerturbation(
    N=1024,
    sample_rate=16000,
    alpha_high=1.2,
    alpha_low=1.2,
    delta_max=512,
    prob=1.0,
)
play_audio(
    augmentation(
        [audio]
    )[0],
    sample_rate=16000,
)

new_len=94464
cross_corr=tensor([-0.0046, -0.0051, -0.0056,  ..., -0.0059, -0.0059, -0.0058])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([ 0.0134,  0.0140,  0.0149,  ..., -0.0072, -0.0073, -0.0074])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([-0.0732, -0.0737, -0.0747,  ...,  0.0298,  0.0305,  0.0307])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([-0.0927, -0.0877, -0.0823,  ..., -0.0014, -0.0006,  0.0004])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([-2.7807e+00, -4.3720e+00, -5.7539e+00,  ..., -1.0145e-02,
        -8.2858e-04,  7.1138e-03])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([ -6.2454,  -8.4338, -10.5265,  ...,   0.1290,   0.1162,   0.1037])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([-2.0561, -0.8229,  0.4008,  ...,  0.1509, -0.0137, -0.1794])
cross_corr.shape=torch.Size([1025])
cross_corr=tensor([-16.4827, -18.6570, -20.2342,  ...,   4.5099,   5.0873,   5.6311])
cross_corr.shape=torch.Size([1025])
cross_corr=ten

In [8]:
a = np.arange(1, 6)
v = np.arange(6, 8)
np.correlate(a, v)

array([20, 33, 46, 59])

In [9]:
x = torch.tensor(a, dtype=int)
w = torch.tensor(v, dtype=int)

augmentation.cross_correlate(w, x)

tensor([59, 46, 33, 20])

In [10]:
wsola(audio, 0.7) - augmentation(
    [audio]
)[0].numpy().max()

'55104'

'sw_pos=array([    0,   512,  1024,  1536,  2048,  2560,  3072,  3584,  4096,\n        4608,  5120,  5632,  6144,  6656,  7168,  7680,  8192,  8704,\n        9216,  9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312,\n       13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920,\n       18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528,\n       23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136,\n       27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744,\n       32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352,\n       36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960,\n       41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568,\n       46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176,\n       50688, 51200, 51712, 52224, 52736, 53248, 53760, 54272, 54784,\n       55296])'

'sw_pos.shape=(109,)'

'aw_pos=array([    0,   731,  1463,  2194,  2926,  3657,  4389,  5120,  5851,\n        6583,  7314,  8046,  8777,  9509, 10240, 10971, 11703, 12434,\n       13166, 13897, 14629, 15360, 16092, 16823, 17554, 18286, 19017,\n       19749, 20480, 21212, 21943, 22674, 23406, 24137, 24869, 25600,\n       26332, 27063, 27794, 28526, 29257, 29989, 30720, 31452, 32183,\n       32914, 33646, 34377, 35109, 35840, 36572, 37303, 38034, 38766,\n       39497, 40229, 40960, 41692, 42423, 43155, 43886, 44617, 45349,\n       46080, 46812, 47543, 48275, 49006, 49737, 50469, 51200, 51932,\n       52663, 53395, 54126, 54857, 55589, 56320, 57052, 57783, 58515,\n       59246, 59977, 60709, 61440, 62172, 62903, 63635, 64366, 65097,\n       65829, 66560, 67292, 68023, 68755, 69486, 70218, 70949, 71680,\n       72412, 73143, 73875, 74606, 75338, 76069, 76800, 77532, 78263,\n       78995])'

'aw_pos.shape=(109,)'

'ana_hop=array([  0, 731, 732, 731, 732, 731, 732, 731, 731, 732, 731, 732, 731,\n       732, 731, 731, 732, 731, 732, 731, 732, 731, 732, 731, 731, 732,\n       731, 732, 731, 732, 731, 731, 732, 731, 732, 731, 732, 731, 731,\n       732, 731, 732, 731, 732, 731, 731, 732, 731, 732, 731, 732, 731,\n       731, 732, 731, 732, 731, 732, 731, 732, 731, 731, 732, 731, 732,\n       731, 732, 731, 731, 732, 731, 732, 731, 732, 731, 731, 732, 731,\n       732, 731, 732, 731, 731, 732, 731, 732, 731, 732, 731, 731, 732,\n       731, 732, 731, 732, 731, 732, 731, 731, 732, 731, 732, 731, 732,\n       731, 731, 732, 731, 732])'

cross_corr=array([0.02619893, 0.02530554, 0.02475689, ..., 0.0081623 , 0.00789707,
       0.00762186], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([ 0.03384878,  0.03525769,  0.03661398, ..., -0.00230931,
       -0.00313931, -0.00378975], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-0.00739006, -0.0168961 , -0.02693434, ..., -0.00319701,
       -0.00216035, -0.00180449], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([17.446142 , 16.282038 , 14.822403 , ...,  1.0346048,  0.3061338,
       -0.4026743], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([-5.106022  , -4.231931  , -3.2152295 , ..., -3.7088046 ,
       -2.3615537 , -0.99671185], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([4.945966 , 5.242058 , 5.3545084, ..., 6.7825403, 6.2591376,
       5.4482927], dtype=float32)
cross_corr.shape=(1025,)
cross_corr=array([ 1.040566 ,  0.9338232,  0.810479 , ..., -8.744036 , -8.830133 ,
       -8.682391 ], dtype=float32)
cross_corr.s

array([-0.41186122, -0.41181434, -0.41217723, ..., -0.43052027,
       -0.42961884, -0.42992475])