In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are re-imported before each cell execution and edits to .py files are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
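# Optional environment setup, kept commented out on purpose. Run the relevant
# line once if sox / librosa / torchaudio are not installed yet.
# !sudo conda install -c conda-forge sox -y
# !sudo apt-get install sox libsox-dev libsox-fmt-all -y
# !sudo conda install -c conda-forge librosa -y
# !sudo conda install -c pytorch torchaudio -y
# !pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html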

In [3]:
import torch
import IPython
import librosa
import torchaudio
from tqdm import tqdm_notebook as tqdm

In [4]:
# Directory that holds the per-video .wav files; passed to AudioDataset as root_dir.
data_dir = "../dltraining/wavs"
# NOTE: despite the "_dir" suffix, the three names below are paths to JSON
# metadata *files*, not directories. full_meta_dir is passed to AudioDataset as
# metadata_file; train/valid are presumably the corresponding splits (not used
# in the cells visible here -- TODO confirm).
train_meta_dir = "train_audio_meta.json"
valid_meta_dir = "valid_audio_meta.json"
full_meta_dir = "full_audio_meta.json"

In [10]:
import torch
from pathlib import Path
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
import os

class AudioDataset(Dataset):
    """Dataset yielding normalised log-mel spectrogram windows for one video's audio.

    The metadata file is a JSON object mapping a video filename (e.g.
    ``"abc.mp4"``) to a dict that has an ``"audio_label"`` entry (``"FAKE"`` or
    anything else, treated as real) and optionally an ``"original"`` entry.
    The audio for ``abc.mp4`` is read from ``root_dir / "abc.wav"``.

    Args:
        root_dir: Directory containing the ``.wav`` files.
        metadata_file: Path to the JSON metadata file described above.
        transform: Optional callable applied to the normalised spectrogram
            before windows are cut from it.
        isBalanced: If true, ``__getitem__`` ignores its index and draws a
            random FAKE or non-FAKE entry with equal probability; the dataset
            length becomes the size of the smaller class.
        num_sequences: Number of windows returned per item.
        fft_multiplier: ``n_fft`` of the mel spectrogram is
            ``fft_multiplier * round(duration_in_seconds)``.
        sequence_length: Width (in spectrogram frames) of each window.
        select_type: Stored on the instance; not used by this class.
        isValid: If true, windows are taken at deterministic positions;
            otherwise their start positions are drawn with ``np.random``.
    """

    def __init__(self, root_dir, metadata_file, transform=None, isBalanced=False, num_sequences=3, fft_multiplier = 20, sequence_length=200, select_type="ordered", isValid=False):
        self.root_dir = Path(root_dir)
        # Context manager so the metadata file handle is closed deterministically.
        with open(metadata_file, 'r') as metadata_fp:
            self.metadata = json.load(metadata_fp)
        self.list_videos = list(self.metadata.keys())

        self.isBalanced = isBalanced
        if isBalanced:
            self.fake_list = [key for key, val in self.metadata.items() if val['audio_label'] == 'FAKE']
            self.real_list = [key for key, val in self.metadata.items() if val['audio_label'] != 'FAKE']
            self.length = min(len(self.fake_list), len(self.real_list))
        else:
            self.length = len(self.metadata)

        self.num_sequences = num_sequences
        self.sequence_length = sequence_length
        self.select_type = select_type
        self.transform = transform
        self.fft_multiplier = fft_multiplier
        self.isValid = isValid

    def init_workers_fn(self, worker_id):
        """Re-seed NumPy's global RNG from OS entropy (for ``worker_init_fn``).

        ``worker_id`` is accepted to match the DataLoader callback signature
        and is otherwise unused.
        """
        new_seed = int.from_bytes(os.urandom(4), byteorder='little')
        np.random.seed(new_seed)

    def collate_fn(self, samples):
        """Transpose a list of items into per-field tuples; labels become a tensor.

        The spectrogram blocks are left as a tuple of lists (not stacked).
        """
        source_filenames, audios, labels, video_original_filenames = zip(*samples)
        return source_filenames, audios, torch.tensor(labels), video_original_filenames

    def __len__(self):
        return self.length

    def _pick_entry(self, idx):
        """Return the metadata entry to serve: random per class if balanced, else by index."""
        if not self.isBalanced:
            return self.list_videos[idx]
        # Coin flip between the two classes, then a uniform draw inside the class.
        choice = np.random.randint(2)
        pool = self.fake_list if choice == 0 else self.real_list
        return pool[np.random.randint(len(pool))]

    def _load_spectrum(self, sound_filename):
        """Load a wav file and return its standardised log-mel spectrogram."""
        wave, sr = torchaudio.load(sound_filename)
        num_seconds = wave.shape[1] / sr
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=self.fft_multiplier * round(num_seconds),
        )
        # Average the channels to mono before computing the spectrogram.
        spectrum = mel(wave.mean(dim=0))
        # Small offset keeps log() finite on silent bins.
        spectrum = (spectrum + 1e-9).log()
        # NOTE: a constant spectrogram has zero std and yields NaN/inf here.
        s_mean, s_std = spectrum.mean(), spectrum.std()
        return (spectrum - s_mean) / s_std

    def _window_starts(self, num_frames):
        """Return ``num_sequences`` start frames for windows of ``sequence_length``.

        Raises:
            ValueError: if the spectrogram has fewer than ``sequence_length`` frames.
        """
        window = self.sequence_length
        last_start = num_frames - window  # largest start that still fits a full window
        if last_start < 0:
            raise ValueError(
                f"spectrogram has {num_frames} frames, fewer than sequence_length={window}"
            )
        if not self.isValid:
            # randint's upper bound is exclusive, so +1 makes last_start reachable
            # (and makes the num_frames == sequence_length case work).
            return [int(np.random.randint(last_start + 1)) for _ in range(self.num_sequences)]
        offset = last_start // self.num_sequences
        # Consecutive windows starting at `offset`; clamp so that no window runs
        # past the end of the spectrogram (which would give a short/empty block).
        return [min(offset + i * window, last_start) for i in range(self.num_sequences)]

    def __getitem__(self, idx):
        """Return ``(source_filename, spectrum_blocks, label, video_original_filename)``.

        ``spectrum_blocks`` is a list of ``num_sequences`` slices of the
        spectrogram along dim 1, each ``sequence_length`` frames wide.
        ``label`` is 1 when the entry's ``"audio_label"`` is ``"FAKE"``, else 0.
        ``video_original_filename`` is the entry's ``"original"`` value or ``None``.
        """
        # Only the stem is used, so str and Path entries behave the same.
        stem = Path(self._pick_entry(idx)).stem
        source_filename = f"{stem}.mp4"
        video_metadata = self.metadata[source_filename]

        spectrum = self._load_spectrum(self.root_dir / f"{stem}.wav")
        if self.transform is not None:
            spectrum = self.transform(spectrum)

        num_ff = self.sequence_length
        spectrum_blocks = [
            spectrum[:, start:start + num_ff]
            for start in self._window_starts(spectrum.shape[1])
        ]

        label = 1 if video_metadata["audio_label"] == 'FAKE' else 0
        video_original_filename = video_metadata.get("original")

        return source_filename, spectrum_blocks, label, video_original_filename

In [15]:
# Class-balanced dataset over the full metadata file. Each item is a random
# FAKE / non-FAKE draw, so the loader itself does not need to shuffle.
ad = AudioDataset(data_dir, full_meta_dir, isBalanced=True, transform=None)

# worker_init_fn re-seeds NumPy in every worker process.
dl = DataLoader(
    ad,
    batch_size=16,
    shuffle=False,
    num_workers=8,
    collate_fn=ad.collate_fn,
    pin_memory=True,
    drop_last=False,
    worker_init_fn=ad.init_workers_fn,
)

In [17]:
# Create an iterator over the DataLoader so single batches can be pulled manually.
it = iter(dl)

In [18]:
# Pull one batch as produced by AudioDataset.collate_fn:
# (source filenames, spectrogram blocks, label tensor, original filenames).
# NOTE(review): displaying the whole batch dumps every tensor into the cell output.
next(it)

(('rgmdqlpukt.mp4',
  'ddjpztlbqq.mp4',
  'cyrmhszohp.mp4',
  'rtwipuesde.mp4',
  'mhvjlwbcnw.mp4',
  'zvdnduoknx.mp4',
  'ebdblblfiy.mp4',
  'oqgezyhmau.mp4',
  'ivpufxwmth.mp4',
  'sowmxypmbs.mp4',
  'rvuspjfybf.mp4',
  'ciwzwmlyxz.mp4',
  'ogucjkrnsu.mp4',
  'jejzjwuumk.mp4',
  'ukqmdmrgoh.mp4',
  'giwwoaeuzi.mp4'),
 ([tensor([[-2.2240, -2.2240, -2.2240,  ..., -2.2240, -2.2240, -2.2240],
           [-2.2240, -2.2240, -2.2240,  ..., -2.2240, -2.2240, -2.2240],
           [ 0.1724, -0.0104,  0.0938,  ...,  0.4136,  0.5334,  0.5974],
           ...,
           [-1.0360, -1.2467, -1.3217,  ..., -0.2668, -0.0432, -0.3654],
           [-0.9134, -1.1194, -1.1739,  ..., -0.1626, -0.2167, -0.2583],
           [-1.2672, -1.0992, -1.1204,  ..., -0.1987, -0.3073, -0.5172]]),
   tensor([[-2.2240, -2.2240, -2.2240,  ..., -2.2240, -2.2240, -2.2240],
           [-2.2240, -2.2240, -2.2240,  ..., -2.2240, -2.2240, -2.2240],
           [ 0.6766,  0.6078,  0.6203,  ...,  0.2050,  0.3922, -0.2529],
    

In [13]:
# Load a single clip directly with torchaudio as a (waveform, sample rate) pair.
# NOTE(review): the path repeats the directory already stored in data_dir, and
# neither variable is used by the cells visible after this one.
sound, sample_rate = torchaudio.load("../dltraining/wavs/lsbttztgcp.wav")

In [35]:
# Show an inline audio player for this clip (listening check).
IPython.display.Audio("../dltraining/wavs/itmnwcrluu.wav")

In [36]:
# Show an inline audio player for a second clip (listening check).
IPython.display.Audio("../dltraining/wavs/uibsxskypy.wav")