In [None]:
!pip install yt-dlp spleeter -q

In [None]:
import os
import re
import subprocess
import contextlib
import logging
from functools import wraps
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import librosa
import yt_dlp
from spleeter.separator import Separator
import ffmpeg

import torch
import torchaudio
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor

In [3]:
def suppress_output(func):
    """Decorator that silences a function's Python-level console output.

    While the wrapped function runs:

    * ``sys.stdout`` and ``sys.stderr`` are redirected to ``os.devnull``;
    * the root logger's handler list is replaced by a single
      ``logging.NullHandler``.

    The original root-logger handlers are restored afterwards, even when the
    wrapped function raises. Only the *root* logger's handlers are swapped;
    loggers that carry their own handlers are not touched by this decorator.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same signature and return value as ``func``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        root = logging.getLogger()
        with contextlib.ExitStack() as stack:
            # Contexts are unwound in reverse order: stderr, stdout, then the file.
            sink = stack.enter_context(open(os.devnull, 'w'))
            stack.enter_context(contextlib.redirect_stdout(sink))
            stack.enter_context(contextlib.redirect_stderr(sink))

            saved_handlers = root.handlers.copy()
            root.handlers = [logging.NullHandler()]
            try:
                return func(*args, **kwargs)
            finally:
                # Always put the caller's logging configuration back.
                root.handlers = saved_handlers
    return wrapper

In [4]:
def seconds_to_hh_mm_ss(total_seconds):
    """Format a duration in seconds as a zero-padded ``HH:MM:SS`` string.

    Args:
        total_seconds: Duration in seconds. Non-integer values (e.g. floats)
            are truncated to whole seconds with ``int()`` so that every field
            is rendered as an integer.

    Returns:
        A string such as ``"01:02:05"``. The hours field is not wrapped at 24,
        so durations of a day or more yield e.g. ``"25:00:00"``.
    """
    # Without the int() coercion a float such as 5.0 is formatted as
    # "0.0:0.0:5.0", because the ``02`` format spec does not integer-format floats.
    whole_seconds = int(total_seconds)
    minutes, seconds = divmod(whole_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02}"

In [5]:
@suppress_output
def download_youtube_audio(url):
    """Download the audio track of a video and convert it to WAV.

    The file is written to the current working directory using the output
    template ``%(id)s.%(ext)s``; the ``FFmpegExtractAudio`` post-processor is
    configured with ``preferredcodec='wav'``.

    Args:
        url: URL of the video to download.

    Returns:
        ``pathlib.Path`` of the form ``<video id>.wav`` (relative to the
        current working directory).
    """
    ydl_opts = {
        'format': 'bestaudio',
        'format_sort': ['+size'],
        'outtmpl': '%(id)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }]
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        # extract_info(download=True) downloads and returns the metadata in a
        # single pass, so the URL is not resolved a second time just to read
        # the video id.
        video_info = ydl.extract_info(url, download=True)

    return Path(f"{video_info['id']}.wav")

In [6]:
@suppress_output
def slice_wav_by_length(input_file, outfile_name_pattern, desired_length=600):
    """Split an audio file into fixed-length WAV segments with FFmpeg.

    Segments are written as ``<outfile_name_pattern>000.wav``,
    ``<outfile_name_pattern>001.wav``, ... using FFmpeg's ``segment`` muxer
    with stream copy (``-c copy``).

    Args:
        input_file: Path (``str`` or ``pathlib.Path``) of the file to split.
        outfile_name_pattern: Prefix for the generated segment files.
        desired_length: Segment length in seconds (default 600).

    Raises:
        subprocess.CalledProcessError: If FFmpeg exits with a non-zero status.
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    command = [
        'ffmpeg',
        '-hide_banner',  # Hide the FFmpeg banner
        '-loglevel', 'error',  # Only report errors
        '-i', str(input_file),
        '-f', 'segment',
        '-segment_time', str(desired_length),
        '-c', 'copy',
        f"{outfile_name_pattern}%03d.wav"
    ]

    # check=True makes a failed FFmpeg run raise instead of being silently
    # ignored, so later steps do not run against missing or stale segments.
    subprocess.run(command, check=True)

In [7]:
@suppress_output
def seperate_vocal_and_accompanies(separator, output_paths):
    """Run source separation on every file in ``output_paths``.

    Results are written below the ``output`` directory (created if missing).
    The destination handed to ``separator.separate_to_file`` for each input is
    ``output/<input file name>``, i.e. it keeps the input's extension.

    Args:
        separator: Object exposing ``separate_to_file(source, destination)``.
        output_paths: Iterable of audio file paths to separate.
    """
    destination_root = 'output'
    os.makedirs(destination_root, exist_ok=True)

    progress = tqdm(output_paths, desc="Processing files")
    for source_path in progress:
        destination = os.path.join(destination_root, os.path.basename(source_path))
        separator.separate_to_file(source_path, destination)

In [8]:
def get_filepaths_with_string_and_extension(
    root_directory='.', target_string='', extension=''):
    """Recursively collect files whose name matches a substring and a suffix.

    Args:
        root_directory: Directory to walk (including all sub-directories).
        target_string: Substring that must occur in the file *name* (not the
            directory part). An empty string matches every file.
        extension: Required file-name suffix, compared with ``str.endswith``.
            An empty string disables the suffix check.

    Returns:
        A sorted list of absolute paths of the matching files.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root_directory):
        for name in filenames:
            if target_string not in name:
                continue
            if extension and not name.endswith(extension):
                continue
            matches.append(os.path.abspath(os.path.join(dirpath, name)))

    matches.sort()
    return matches

In [12]:
def load_and_slice_audio(vocal_paths, slice_duration=10):
    """Load audio files and cut each one into equal-length slices.

    Each file is loaded with ``torchaudio.load`` and divided into consecutive,
    non-overlapping slices of ``slice_duration`` seconds. A trailing remainder
    shorter than one slice is dropped for every file.

    Args:
        vocal_paths: Iterable of audio file paths.
        slice_duration: Slice length in seconds (default 10).

    Returns:
        A tuple ``(all_samples, all_sampling_rates)``:

        * ``all_samples`` - list of tensors, one per slice, each with the
          channel axis first and the slice's frames last;
        * ``all_sampling_rates`` - list with the sampling rate of the file
          each slice came from (same length as ``all_samples``).
    """
    all_samples = []
    all_sampling_rates = []

    for path in tqdm(vocal_paths):
        # Assumes torchaudio returns a (channels, frames) tensor. The channel
        # axis is kept on purpose: squeezing it turns a mono file into a 1-D
        # array and the 2-D slicing below would then raise IndexError.
        waveform, sampling_rate = torchaudio.load(path)

        # int() keeps the slice bounds valid even for a fractional duration.
        frames_per_slice = int(sampling_rate * slice_duration)
        num_slices = waveform.shape[-1] // frames_per_slice

        samples = [
            waveform[:, i * frames_per_slice:(i + 1) * frames_per_slice]
            for i in range(num_slices)
        ]

        all_samples.extend(samples)
        all_sampling_rates.extend([sampling_rate] * num_slices)

    return all_samples, all_sampling_rates

In [13]:
def get_common_sampling_rate(sampling_rates):
    """Return the single sampling rate shared by all entries.

    Args:
        sampling_rates: Collection of sampling rates.

    Returns:
        The common sampling rate as an ``int``.

    Raises:
        ValueError: If ``sampling_rates`` is empty, or if it contains more
            than one distinct value.
    """
    distinct_rates = set(sampling_rates)

    # An empty input used to fall through to the "should be the same" error,
    # which hides the real problem (nothing was loaded).
    if not distinct_rates:
        raise ValueError("No sampling rates were provided.")
    if len(distinct_rates) > 1:
        raise ValueError("All sampling rates should be the same.")

    return int(distinct_rates.pop())

In [45]:
def predict_samples(model, feature_extractor, monofied_samples, sampling_rate, device):
    """Classify every sample and return its top label.

    Each sample is passed through ``feature_extractor``, the resulting
    ``input_values`` are moved to ``device`` and fed to ``model`` without
    gradient tracking. The label is looked up in ``model.config.id2label``
    using the arg-max of the logits.

    Args:
        model: Callable returning an object with a ``logits`` attribute and
            exposing ``config.id2label``.
        feature_extractor: Callable returning an object with ``input_values``.
        monofied_samples: Sequence of samples (one per slice).
        sampling_rate: Sampling rate forwarded to the feature extractor.
        device: Device the extracted features are moved to.

    Returns:
        A list of ``[position, label]`` pairs, where ``position`` is the
        sample's index in ``monofied_samples``.
    """
    predictions = []

    progress = tqdm(
        enumerate(monofied_samples),
        total=len(monofied_samples),
        desc="Processing samples",
    )
    for position, waveform in progress:
        features = feature_extractor(
            waveform,
            feature_size=2,
            sampling_rate=sampling_rate,
            padding="max_length",
            return_tensors="pt",
        )
        batch = features.input_values.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(batch).logits

        best_class_id = logits.argmax(-1).item()
        predictions.append([position, model.config.id2label[best_class_id]])

    return predictions

In [49]:
def create_dataframe(predictions, time_window_length):
    """Turn ``[index, class]`` predictions into a timestamped DataFrame.

    Args:
        predictions: Sequence of ``[index, class]`` pairs.
        time_window_length: Length in seconds of the window each index covers.

    Returns:
        DataFrame with columns ``index``, ``class``, ``start`` and ``end``.
        ``start``/``end`` are ``seconds_to_hh_mm_ss`` strings for
        ``index * time_window_length`` and ``(index + 1) * time_window_length``.
    """
    frame = pd.DataFrame(predictions, columns=['index', 'class'])
    window_numbers = frame['index']

    frame['start'] = window_numbers.map(
        lambda n: seconds_to_hh_mm_ss(n * time_window_length)
    )
    frame['end'] = window_numbers.map(
        lambda n: seconds_to_hh_mm_ss((n + 1) * time_window_length)
    )
    return frame

In [88]:
def audio_analyzer(file_paths, model, device, slice_duration):
    """Classify fixed-length slices of the given audio files.

    Pipeline: load and slice the files, average each slice over its first
    axis, build an ``ASTFeatureExtractor`` at the files' common sampling rate,
    predict a label per slice, and tabulate the predictions with timestamps.

    Args:
        file_paths: Audio files to analyse.
        model: Audio-classification model passed on to ``predict_samples``.
        device: Device used for inference.
        slice_duration: Slice length in seconds; also the time-window length
            used for the ``start``/``end`` columns.

    Returns:
        The DataFrame produced by ``create_dataframe``.
    """
    slices, rates = load_and_slice_audio(file_paths, slice_duration)

    # Collapse the first (channel) axis to get one waveform per slice.
    mono_slices = [torch.mean(clip, dim=0) for clip in slices]
    common_rate = get_common_sampling_rate(rates)

    # NOTE(review): the extractor is configured with the files' native rate
    # and no resampling is done here. The AST AudioSet checkpoints are
    # presumably trained on 16 kHz audio - confirm whether resampling is needed.
    extractor = ASTFeatureExtractor(sampling_rate=common_rate, do_normalize=True)

    labelled = predict_samples(model, extractor, mono_slices, common_rate, device)

    return create_dataframe(labelled, time_window_length=slice_duration)

In [17]:
def class_manipulator(
    input_df,
    non_song_labels=('Speech', 'Music', 'Tick', 'Clicking', 'Silence', 'Whistling'),
):
    """Merge consecutive windows into runs and keep only the "song" runs.

    Every row is binarised (0 when its ``class`` is in ``non_song_labels``,
    1 otherwise); consecutive rows with the same flag form one run. Runs whose
    flag is 0 are discarded.

    The same label list drives both the binarisation and the filtering.
    Previously the filter used a second list without 'Whistling' and looked
    only at the first label of a run, so a non-song run that happened to start
    with 'Whistling' was returned as if it were a song.

    Args:
        input_df: DataFrame with ``class``, ``start`` and ``end`` columns, in
            chronological row order. It is not modified.
        non_song_labels: Labels treated as "not a song".

    Returns:
        DataFrame with columns ``start`` (first start of the run), ``end``
        (last end of the run) and ``class_value`` (first label of the run),
        sorted by ``start`` with a fresh 0..n-1 index.
    """
    # Work on a copy so the caller's frame is left untouched.
    df = input_df.copy()

    # 1 = candidate song window, 0 = known non-song label.
    df['binarized_class'] = (~df['class'].isin(list(non_song_labels))).astype(int)

    # A new run starts whenever the flag differs from the previous row.
    df['group'] = (df['binarized_class'] != df['binarized_class'].shift()).cumsum()

    runs = df.groupby(['binarized_class', 'group']).agg(
        start=('start', 'first'),
        end=('end', 'last'),
        class_value=('class', 'first')
    ).reset_index()

    # Keep song runs only, decided by the run's flag rather than its first label.
    song_runs = runs.loc[runs['binarized_class'] == 1, ['start', 'end', 'class_value']]

    return song_runs.sort_values('start').reset_index(drop=True)

In [18]:
def group_same_songs(input_df, interval_threshold, duration_threshold):
    """Merge nearby runs into songs and drop songs that are too short.

    Consecutive rows are put in the same group when the gap between a row's
    ``start`` and the previous row's ``end`` is at most ``interval_threshold``
    seconds. Groups whose overall span (latest ``end`` minus earliest
    ``start``) is shorter than ``duration_threshold`` seconds are removed.

    Args:
        input_df: DataFrame with ``start`` and ``end`` columns holding time
            strings, in chronological row order. Any index is accepted; the
            frame itself is not modified.
        interval_threshold: Maximum gap in seconds between two rows of the
            same group.
        duration_threshold: Minimum span in seconds for a group to be kept.

    Returns:
        DataFrame indexed by ``group`` with columns ``start`` (first start of
        the group) and ``end`` (last end of the group). Group numbers of
        dropped groups are not reused. The values are pandas timestamps as
        produced by ``pd.to_datetime``; their date part comes from pandas'
        parsing of the bare time strings and carries no meaning.
    """
    # Positional 0..n-1 index: rows are compared with their predecessor by
    # position, so the caller's index labels must not matter.
    df = input_df.reset_index(drop=True)

    df['start'] = pd.to_datetime(df['start'])
    df['end'] = pd.to_datetime(df['end'])

    # Gap to the previous row; NaN for the first row, which compares False
    # below and therefore opens group 0.
    gap_seconds = (df['start'] - df['end'].shift()).dt.total_seconds()
    df['group'] = (gap_seconds > interval_threshold).cumsum()

    summary = df.groupby('group').agg(
        start=('start', 'first'),
        end=('end', 'last'),
        earliest=('start', 'min'),
        latest=('end', 'max'),
    )

    span_seconds = (summary['latest'] - summary['earliest']).dt.total_seconds()

    return summary.loc[span_seconds >= duration_threshold, ['start', 'end']]


In [91]:
!rm -rf ./output

# Prefix of the fixed-length WAV slices written by slice_wav_by_length.
output_filename_pattern = 'leona'
# Directory that is walked recursively when collecting WAV files below.
root_dir = '.'  
youtube_url = 'https://www.youtube.com/watch?v=hYXjZJPZxtU'

# Spleeter separator using the "2stems" configuration.
separator = Separator("spleeter:2stems")

# Download the audio as WAV, then cut it into slices (default length: 600 s).
slice_wav_by_length(download_youtube_audio(youtube_url), output_filename_pattern)

# Every file under root_dir whose name contains the prefix and ends in 'wav'
# is picked up - this includes slices left behind by earlier runs.
wav_slice_paths = get_filepaths_with_string_and_extension(
    root_dir, target_string = output_filename_pattern, extension='wav'
)

# Writes the separated stems below ./output.
seperate_vocal_and_accompanies(separator, wav_slice_paths)

# Collected for completeness; not used by the cells visible in this notebook.
accompanies_paths = get_filepaths_with_string_and_extension(
    root_dir, target_string = 'accompaniment', extension='wav'
)

# Sorted absolute paths of the vocal stems, i.e. the input of the classifier.
vocal_paths = get_filepaths_with_string_and_extension(
    root_dir, target_string = 'vocal', extension='wav'
)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593").to(device)

Processing files:   0%|          | 0/14 [00:00<?, ?it/s]

INFO:spleeter:File output/leona000.wav/leona000/vocals.wav written succesfully
INFO:spleeter:File output/leona000.wav/leona000/accompaniment.wav written succesfully
INFO:spleeter:File output/leona001.wav/leona001/vocals.wav written succesfully
INFO:spleeter:File output/leona001.wav/leona001/accompaniment.wav written succesfully
INFO:spleeter:File output/leona002.wav/leona002/vocals.wav written succesfully
INFO:spleeter:File output/leona002.wav/leona002/accompaniment.wav written succesfully
INFO:spleeter:File output/leona003.wav/leona003/vocals.wav written succesfully
INFO:spleeter:File output/leona003.wav/leona003/accompaniment.wav written succesfully
INFO:spleeter:File output/leona004.wav/leona004/vocals.wav written succesfully
INFO:spleeter:File output/leona004.wav/leona004/accompaniment.wav written succesfully
INFO:spleeter:File output/leona005.wav/leona005/vocals.wav written succesfully
INFO:spleeter:File output/leona005.wav/leona005/accompaniment.wav written succesfully
INFO:splee

In [92]:
# Classify the vocal stems in 5-second windows; one row per window.
result_vocal = audio_analyzer(vocal_paths, model, device, slice_duration=5)

  0%|          | 0/14 [00:00<?, ?it/s]

Processing samples:   0%|          | 0/1593 [00:00<?, ?it/s]

In [93]:
# Merge runs separated by at most 30 s and drop groups spanning less than 30 s.
# The cell's value (one row per detected song) is displayed as its output.
group_same_songs(
    class_manipulator(result_vocal),
    interval_threshold=30, duration_threshold=30
)

Unnamed: 0_level_0,start,end
group,Unnamed: 1_level_1,Unnamed: 2_level_1
1,2024-01-14 00:01:15,2024-01-14 00:02:50
3,2024-01-14 00:05:35,2024-01-14 00:09:05
4,2024-01-14 00:09:40,2024-01-14 00:14:20
6,2024-01-14 00:16:30,2024-01-14 00:21:05
8,2024-01-14 00:31:00,2024-01-14 00:32:10
9,2024-01-14 00:32:55,2024-01-14 00:34:45
10,2024-01-14 00:35:45,2024-01-14 00:40:50
11,2024-01-14 00:42:10,2024-01-14 00:45:35
12,2024-01-14 00:49:15,2024-01-14 00:52:00
13,2024-01-14 00:55:00,2024-01-14 00:58:45
