In [None]:
import jukemirlib
import torch
import numpy as np

# Select the CUDA device used for the rest of the notebook.
# NOTE(review): index 1 is hard-coded and requires at least two visible GPUs.
gpu_id = 1
torch.cuda.set_device(gpu_id)

# Single MagnaTagATune clip used to try out representation extraction.
fpath = 'data/magnatagatune/f/american_bach_soloists-j_s__bach_solo_cantatas-01-bwv54__i_aria-291-320.wav'

# np.asarray replaces the previous torch.tensor(...).__array__() round trip,
# which only copied the samples into a tensor and straight back to NumPy.
audio = np.asarray(jukemirlib.load_audio(fpath))

# Request layer 36 only; the result is indexed by layer number (reps[36]).
reps = jukemirlib.extract(audio, layers=[36])

In [None]:
# Sanity check: number of loaded audio samples (1284369 was observed for this
# clip), followed by the layer-36 representation wrapped in a tensor.
print(len(audio), torch.tensor(reps[36]), sep="\n")

In [None]:
from clmr.datasets import get_dataset
from clmr.data import ContrastiveDataset
from torchaudio_augmentations import (
    RandomApply,
    ComposeMany,
    RandomResizedCrop,
    PolarityInversion,
    Noise,
    Gain,
    HighLowPass,
    Delay,
    PitchShift,
    Reverb,
)
from torch.utils.data import DataLoader
# --- Dataset -----------------------------------------------------------------
# Training split of MagnaTagATune, read from ./data without downloading.
train_dataset = get_dataset(
    'magnatagatune', './data', subset="train", download=False)

# --- Augmentation settings ---------------------------------------------------
# Number of samples in every augmented crop.
audio_length = 59049
# Probability with which each augmentation is applied (passed as `p` below).
transforms_polarity = 0.8
transforms_noise = 0.01
transforms_gain = 0.3
transforms_filters = 0.8
transforms_delay = 0.3
transforms_pitch = 0.6
transforms_reverb = 0.6
# Not read anywhere in this cell; kept because other cells may rely on it.
load_dataset = True
sample_rate = 22050

# --- Augmentation chain ------------------------------------------------------
# The crop always runs; every other transform is wrapped in RandomApply with
# its own probability.
train_transform = [
    RandomResizedCrop(n_samples=audio_length),
    RandomApply([PolarityInversion()], p=transforms_polarity),
    RandomApply([Noise()], p=transforms_noise),
    RandomApply([Gain()], p=transforms_gain),
    RandomApply(
        [HighLowPass(sample_rate=sample_rate)], p=transforms_filters
    ),
    RandomApply([Delay(sample_rate=sample_rate)], p=transforms_delay),
    RandomApply(
        [
            PitchShift(
                n_samples=audio_length,
                sample_rate=sample_rate,
            )
        ],
        p=transforms_pitch,
    ),
    RandomApply(
        [Reverb(sample_rate=sample_rate)], p=transforms_reverb
    ),
]

# --- Contrastive dataset and loader ------------------------------------------
# Number of augmented views requested from ComposeMany per item.
num_augmented_samples = 2
contrastive_train_dataset = ContrastiveDataset(
    train_dataset,
    input_shape=(1, audio_length),
    transform=ComposeMany(
        # Pass the variable rather than a literal 2 so that changing
        # num_augmented_samples above actually takes effect.
        train_transform, num_augmented_samples=num_augmented_samples
    ),
)
train_loader = DataLoader(
    contrastive_train_dataset,
    batch_size=12,
    num_workers=16,
    persistent_workers=True,
    drop_last=True,
    shuffle=True,
)