In [None]:
# Install the third-party packages that later cells import: musdb (dataset
# access), librosa (audio/STFT), museval and wandb.
# NOTE(review): versions are unpinned, and `optuna`, `pywt` and `yaml` are
# imported further down without being installed here — presumably they are
# preinstalled in the runtime image; confirm.
!pip install musdb
!pip install librosa
!pip install museval
!pip install wandb




In [None]:
import tensorflow as tf
import torch

# Diagnostic cell: prints which accelerators TensorFlow and PyTorch can see and
# dumps CPU/GPU/hardware details through shell commands. Nothing defined here
# is required by later cells except as information for the reader.

# Check for GPU using TensorFlow
physical_devices = tf.config.list_physical_devices()
print("Available physical devices (TensorFlow):", physical_devices)

gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    print("GPU is available (TensorFlow)")
else:
    print("GPU is not available (TensorFlow)")

# Check for GPU using PyTorch
if torch.cuda.is_available():
    print("GPU is available (PyTorch)")
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    # Only device 0 is described, even when several GPUs are present.
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available (PyTorch)")

# Check for TPU
# NOTE(review): only ValueError is treated as "no TPU"; any other exception
# raised by the resolver would propagate and stop the cell — confirm that is
# the only failure mode in the target runtime.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('TPU is available')
    print('TPU devices:', tpu.master())
except ValueError:
    print('TPU is not available')

# System information
print("System information:")
!lscpu

# GPU information
print("GPU information:")
!nvidia-smi

# Full hardware information
print("Full hardware information:")
!lshw -short



# Same value as the "Number of GPUs" line above, but printed unconditionally
# (0 when CUDA is unavailable).
device_count = torch.cuda.device_count()
print(f"Number of available GPUs: {device_count}")

In [None]:
import os

# Diagnostic cell: show where the notebook is running and what is mounted
# under /kaggle, then check that the expected dataset directory exists.
print("Current working directory:", os.getcwd())
print("Contents of root directory:", os.listdir('/'))
print("Contents of /kaggle:", os.listdir('/kaggle'))
print("Contents of /kaggle/input:", os.listdir('/kaggle/input'))

# Try accessing your dataset directory again
# NOTE(review): this probes '/kaggle/input/my-data', while the data-loading
# cell reads from '/kaggle/input/audio-seperation-dataset/musdb18' — confirm
# which path is the intended one.
dataset_path = '/kaggle/input/my-data'
if os.path.exists(dataset_path):
    print("Dataset path verified:", dataset_path)
    print("Contents of dataset path:", os.listdir(dataset_path))
else:
    print("Dataset path does not exist:", dataset_path)

In [None]:
import os

def list_directory_contents(path, level=0, max_items=3):
    """Print an indented tree of ``path``.

    Sub-directories are listed first (alphabetically) and descended into
    recursively; afterwards at most ``max_items`` regular files of the
    directory are printed. Entries that are neither directories nor regular
    files are ignored.

    Args:
        path: Directory to list.
        level: Current nesting depth; controls the indentation prefix.
        max_items: Maximum number of files printed per directory.
    """
    if not os.path.exists(path):
        print(f"Path does not exist: {path}")
        return

    prefix = '│   ' * level

    # Single pass over the sorted listing, bucketing names by kind.
    subdirs, plain_files = [], []
    for name in sorted(os.listdir(path)):
        full_path = os.path.join(path, name)
        if os.path.isdir(full_path):
            subdirs.append(name)
        elif os.path.isfile(full_path):
            plain_files.append(name)

    # Directories: the closing connector is only used for the last directory
    # when no files follow it in this listing.
    final_dir_pos = len(subdirs) - 1
    for pos, name in enumerate(subdirs):
        closes_listing = pos == final_dir_pos and not plain_files
        connector = '└──' if closes_listing else '├──'
        print(f"{prefix}{connector} {name}/")
        list_directory_contents(os.path.join(path, name), level + 1, max_items)

    # Files: only the first `max_items` are shown.
    shown = plain_files[:max_items]
    for pos, name in enumerate(shown, start=1):
        connector = '└──' if pos == len(shown) else '├──'
        print(f"{prefix}{connector} File: {name}")

list_directory_contents('/kaggle/input/audio-seperation-dataset')

In [None]:
# Cell 1: Imports, Configuration, and Logger Setup

import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import musdb
from tqdm import tqdm

class ColoredFormatter(logging.Formatter):
    """Formatter that wraps every rendered record in an ANSI colour escape.

    The colour is looked up by the record's level name; unknown level names
    fall back to the reset sequence, i.e. the message is emitted uncoloured.
    """

    # ANSI escape sequences keyed by logging level name.
    COLORS = {
        'DEBUG': '\033[0;36m',  # Cyan
        'INFO': '\033[0;32m',   # Green
        'WARNING': '\033[0;33m',  # Yellow
        'ERROR': '\033[0;31m',  # Red
        'CRITICAL': '\033[0;37;41m',  # White on Red
        'RESET': '\033[0m',  # Reset
    }

    def format(self, record):
        """Return the standard formatted message surrounded by colour codes."""
        reset = self.COLORS['RESET']
        colour = self.COLORS.get(record.levelname, reset)
        rendered = super().format(record)
        return colour + rendered + reset

def setup_logger(name, log_file, level=logging.DEBUG):
    """Create (or reconfigure) a logger with coloured console and file output.

    Any handlers already attached to the named logger are detached AND closed
    first, so re-running the notebook cell neither duplicates log lines nor
    leaks the previously opened log-file handle.

    Args:
        name: Name passed to ``logging.getLogger``.
        log_file: Path of the rotating log file.
        level: Threshold applied to the logger and to both handlers.

    Returns:
        The configured ``logging.Logger``.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove all existing handlers; close them so file descriptors are released.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    # Console handler with color
    c_handler = logging.StreamHandler(sys.stdout)
    c_handler.setFormatter(ColoredFormatter(log_format))
    c_handler.setLevel(level)

    # File handler (no color); rotates after ~10 kB and keeps one backup.
    f_handler = RotatingFileHandler(log_file, maxBytes=10000, backupCount=1)
    f_handler.setFormatter(logging.Formatter(log_format))
    f_handler.setLevel(level)

    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger

# Set up the logger
# Module-level logger shared by the cells below: coloured output on stdout and
# plain text in the rotating file 'app.log'.
logger = setup_logger('my_logger', 'app.log', level=logging.DEBUG)

class Config:
    """Central store of hyperparameters and runtime settings.

    All values are plain class attributes, so they are shared globally and can
    be reassigned at runtime (the hyperparameter-search cell does this for
    LEARNING_RATE, BATCH_SIZE, HIDDEN_SIZE and NUM_LAYERS).
    """
    SAMPLE_RATE = 44100  # samples per second; multiplied by CHUNK_DURATION to size a chunk
    CHUNK_DURATION = 7  # in seconds
    N_FFT = 2048  # STFT window size passed to librosa.stft
    HOP_LENGTH = 512  # STFT hop passed to librosa.stft
    HIDDEN_SIZE = 128
    NUM_LAYERS = 2
    BIDIRECTIONAL = True
    LEARNING_RATE = 1e-3
    BATCH_SIZE = 16
    NUM_EPOCHS = 50  # epochs per Optuna trial and for the final retraining
    NUM_TRIALS = 7  # number of Optuna trials
    NUM_WORKERS = 4  # DataLoader worker processes
    # NOTE(review): the tuning cell samples its own dropout / weight-decay
    # values; 0.5 is unusually large for Adam weight decay — confirm whether
    # these two defaults are consumed anywhere.
    DROPOUT_RATE = 0.5
    WEIGHT_DECAY = 0.5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    STEMS = ['drums', 'bass', 'other', 'vocals']  # 'mixture' is always included
    LOGGER = logger

    @classmethod
    def log_info(cls, message):
        """Log ``message`` at INFO level through the shared logger."""
        cls.LOGGER.info(message)

# Log initial setup information
Config.log_info(f"Using device: {Config.DEVICE}")
Config.log_info("Initial setup complete.")

# Test the logger
# Emits one message per level so the colour mapping can be checked visually.
Config.LOGGER.debug("This is a debug message")
Config.LOGGER.info("This is an info message")
Config.LOGGER.warning("This is a warning message")
Config.LOGGER.error("This is an error message")
Config.LOGGER.critical("This is a critical message")

# Explanation of libraries:
# - librosa: Used for audio processing tasks like STFT
# - torch: Main library for building and training neural networks
# - musdb: Provides easy access to the MUSDB18 dataset
# - tqdm: Used for progress bars in training loops

Config.log_info("All libraries imported and logger set up successfully.")

In [None]:
# Cell 2: Data Loading and Preprocessing

def load_musdb(root_dir):
    """Open the train and test subsets of a MUSDB dataset rooted at ``root_dir``.

    Both subsets are opened with ``is_wav=False``. Any failure is logged and
    re-raised unchanged.

    Returns:
        Tuple ``(mus_train, mus_test)`` of ``musdb.DB`` objects.
    """
    logger.info(f"Loading MUSDB dataset from {root_dir}")
    try:
        # Dict comprehension keeps the original order: 'train' first, then 'test'.
        by_split = {
            split: musdb.DB(root_dir, subsets=[split], is_wav=False)
            for split in ('train', 'test')
        }
        mus_train = by_split['train']
        mus_test = by_split['test']
        logger.info(f"Successfully loaded {len(mus_train)} train tracks and {len(mus_test)} test tracks from MUSDB")
    except Exception as e:
        logger.error(f"Failed to load MUSDB dataset: {str(e)}")
        raise
    return mus_train, mus_test

class CustomAudioMixer:
    """Draws random fixed-length excerpts (mixture + stems) from MUSDB tracks.

    Randomness comes from the torch RNG rather than ``np.random``: PyTorch
    re-seeds its own generator per DataLoader worker, whereas NumPy's global
    state is copied into every forked worker, which would make all workers
    return the same "random" chunks.
    """

    def __init__(self, mus, target='vocals'):
        """
        Args:
            mus: Indexable collection of tracks (supports ``len`` and ``[i]``).
            target: Name of the stem returned by :meth:`mix`.
        """
        self.mus = mus
        self.target = target
        self.chunk_samples = int(Config.SAMPLE_RATE * Config.CHUNK_DURATION)
        logger.info("Initialized CustomAudioMixer")

    def get_random_chunk(self, track):
        """Select a random ``Config.CHUNK_DURATION``-second window of ``track``.

        The window is set through ``track.chunk_start`` / ``track.chunk_duration``
        before the audio is read.

        Returns:
            ``(mixture, stems)`` where ``mixture`` is ``track.audio.T`` and
            ``stems`` maps every name in ``Config.STEMS`` to that stem's
            transposed audio.
        """
        # Tracks shorter than one chunk always start at 0.
        max_start = max(track.duration - Config.CHUNK_DURATION, 0)
        start_time = torch.rand(1).item() * max_start
        track.chunk_start = start_time
        track.chunk_duration = Config.CHUNK_DURATION

        # Get all stems
        stems = {stem: track.targets[stem].audio.T for stem in Config.STEMS}
        mixture = track.audio.T

        return mixture, stems

    def mix(self):
        """Return ``(mixture, target_stem)`` for a random chunk of a random track."""
        track_index = torch.randint(len(self.mus), (1,)).item()
        track = self.mus[track_index]
        mixture, stems = self.get_random_chunk(track)
        return mixture, stems[self.target]

class AudioTransforms:
    """Stateless helpers turning raw audio into normalised dB spectrograms."""

    @staticmethod
    def to_magnitude_spectrogram(audio):
        """Magnitude of the STFT computed with ``Config.N_FFT`` / ``Config.HOP_LENGTH``."""
        return np.abs(
            librosa.stft(audio, hop_length=Config.HOP_LENGTH, n_fft=Config.N_FFT)
        )

    @staticmethod
    def amplitude_to_db(spectrogram):
        """Convert amplitudes to decibels relative to the spectrogram's maximum."""
        decibels = librosa.amplitude_to_db(spectrogram, ref=np.max)
        return decibels

    @staticmethod
    def normalize(spectrogram):
        """Standardise to zero mean / unit variance over the whole array.

        A small constant (1e-6) is added to the standard deviation so constant
        inputs do not divide by zero.
        """
        centered = spectrogram - spectrogram.mean()
        scale = spectrogram.std() + 1e-6
        return centered / scale

# Load dataset
# Opens the MUSDB train/test subsets; `mus_train` and `mus_test` are consumed
# by the dataset cell that follows.
musdb_root = '/kaggle/input/audio-seperation-dataset/musdb18'  # Update this path
mus_train, mus_test = load_musdb(musdb_root)

logger.info("Data loading and preprocessing functions defined.")

In [None]:
# Cell 3: Dataset Class
class SourceSeparationDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (mixture, target) spectrogram pairs.

    Item ``idx`` is a random chunk of track ``idx``. Honouring the index
    (instead of sampling any track at random) is what makes index-based splits
    such as ``random_split`` actually separate tracks between subsets.

    Each spectrogram is: STFT magnitude -> dB -> per-item standardisation, and
    is returned channel-first, i.e. ``(channels, freq_bins, frames)``.
    """

    def __init__(self, mus, target='vocals'):
        """
        Args:
            mus: Indexable collection of MUSDB tracks.
            target: Name of the stem used as the separation target.
        """
        self.mus = mus
        self.mixer = CustomAudioMixer(mus, target)
        self.transforms = AudioTransforms()
        self.target = target
        logger.info(f"Initialized SourceSeparationDataset with {len(mus)} tracks")

    def __len__(self):
        return len(self.mus)

    def _to_features(self, audio):
        """Run the magnitude -> dB -> normalise chain on one audio array."""
        spec = self.transforms.to_magnitude_spectrogram(audio)
        spec = self.transforms.amplitude_to_db(spec)
        spec = self.transforms.normalize(spec)
        # Mono input yields a 2-D (freq, frames) array; add a leading channel
        # axis so the layout matches the multi-channel case.
        if spec.ndim == 2:
            spec = spec[np.newaxis, ...]
        return spec

    def __getitem__(self, idx):
        # Use the requested track; only the chunk position is random.
        track = self.mus[idx]
        mixture, stems = self.mixer.get_random_chunk(track)
        target = stems[self.target]

        mixture_spec = self._to_features(mixture)
        target_spec = self._to_features(target)

        # Only log every 20th sample
        if idx % 20 == 0:
            logger.debug(f"Mixture spec shape: {mixture_spec.shape}, Target spec shape: {target_spec.shape}")

        return torch.FloatTensor(mixture_spec), torch.FloatTensor(target_spec)

# Create dataset instances
train_dataset = SourceSeparationDataset(mus_train)
test_dataset = SourceSeparationDataset(mus_test)

# If you want to create a validation set from the train set:
# 80/20 split by index. `train_dataset` is rebound to the training subset, so
# re-running this cell on its own would split the already-split subset.
# No generator is passed to random_split, so the partition depends on the
# global torch RNG state at the time the cell runs.
train_size = int(0.8 * len(mus_train))
val_size = len(mus_train) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

logger.info(f"Datasets created. Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)} tracks")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_spectrogram(spectrogram, title):
    """Draw ``spectrogram`` (2-D array) with a colour bar and ``title``.

    The plot goes into the currently active axes when a figure is already
    open (so it cooperates with a caller's ``plt.subplot`` layout); otherwise
    a dedicated 10x4 figure is created.

    Degenerate inputs (all-NaN, all-Inf, or constant) are reported as text
    instead of an image. The colour bar is only added when an image was
    actually drawn, because ``plt.colorbar`` raises without a mappable.
    """
    if plt.get_fignums():
        ax = plt.gca()
    else:
        _, ax = plt.subplots(figsize=(10, 4))

    image = None
    if np.all(np.isnan(spectrogram)) or np.all(np.isinf(spectrogram)):
        ax.text(0.5, 0.5, f"Invalid Spectrogram: contains NaN or Inf", ha='center', va='center')
    elif np.all(spectrogram == spectrogram[0, 0]):
        ax.text(0.5, 0.5, f"Constant Spectrogram: all values are {spectrogram[0, 0]}", ha='center', va='center')
    else:
        image = ax.imshow(spectrogram, aspect='auto', origin='lower', interpolation='nearest')

    ax.set_title(title)
    if image is not None:
        plt.colorbar(image, ax=ax, format='%+2.0f dB')
    plt.tight_layout()

def visualize_example(dataset):
    """Plot the mixture and target spectrograms of one randomly chosen item.

    For 3-D items only the first slice along axis 0 is shown. Shape and value
    range of both spectrograms are printed before plotting.
    """
    # Pick a random item and convert both tensors to numpy arrays.
    mixture_tensor, target_tensor = dataset[np.random.randint(len(dataset))]
    mixture_spec, target_spec = (
        arr[0] if arr.ndim == 3 else arr
        for arr in (mixture_tensor.numpy(), target_tensor.numpy())
    )

    for label, spec in (("Mixture", mixture_spec), ("Target", target_spec)):
        print(f"{label} spectrogram shape: {spec.shape}")
        print(f"{label} spectrogram min: {spec.min()}, max: {spec.max()}")

    # Two stacked panels: mixture on top, target below.
    plt.figure(figsize=(15, 8))
    panels = (
        (mixture_spec, 'Mixture Spectrogram'),
        (target_spec, 'Target (Vocals) Spectrogram'),
    )
    for row, (spec, heading) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plot_spectrogram(spec, heading)

    plt.tight_layout()
    plt.show()

# Visualize a single random example
# Each run picks a different item because the index is drawn at random.
logger.info("Visualizing a random example from the dataset")
visualize_example(train_dataset)

# Visualize the distribution of spectrogram values
def plot_spectrogram_distribution(dataset, num_samples=1):
    """Plot a density histogram of spectrogram values from random items.

    For each of ``num_samples`` randomly chosen items, the values of both the
    mixture and the target spectrogram are pooled into one histogram.

    The values are gathered as NumPy arrays and concatenated once; extending a
    Python list element by element would box millions of scalars.
    """
    pieces = []
    for _ in range(num_samples):
        mixture, target = dataset[np.random.randint(len(dataset))]
        pieces.append(mixture.numpy().ravel())
        pieces.append(target.numpy().ravel())
    spec_values = np.concatenate(pieces) if pieces else np.empty(0)

    plt.figure(figsize=(10, 4))
    plt.hist(spec_values, bins=50, density=True)
    plt.title('Distribution of Spectrogram Values')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.tight_layout()
    plt.show()

# Histogram over a single random item (num_samples defaults to 1).
logger.info("Plotting distribution of spectrogram values")
plot_spectrogram_distribution(train_dataset)

logger.info("Dataset visualization complete.")


In [None]:
# Cell 4: DataLoader Setup

# Single-process loaders using the default Config.BATCH_SIZE.
# NOTE(review): the tuning cell builds its own train/val loaders and the
# evaluation cells rebind `test_loader`, so these objects appear to be unused
# by the cells visible here — confirm before relying on them.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

logger.info("DataLoaders created.")

In [None]:
class MaskInference(nn.Module):
    """LSTM that maps a stereo spectrogram to an output of the same shape.

    Input and output are ``(batch, channels, freq_bins, time_steps)`` with
    ``channels * freq_bins == (Config.N_FFT // 2 + 1) * 2``.

    Args:
        hidden_size: LSTM hidden units; defaults to ``Config.HIDDEN_SIZE``.
        num_layers: Stacked LSTM layers; defaults to ``Config.NUM_LAYERS``.
        bidirectional: Defaults to ``Config.BIDIRECTIONAL``.
        dropout: Inter-layer LSTM dropout (ignored for a single layer).

    Defaults are read from ``Config`` when the model is constructed, so a
    checkpoint must be reloaded under the same ``Config`` values it was
    trained with. With the stock ``Config`` the parameter shapes are the same
    as the previous hard-coded 128-unit, 2-layer, bidirectional network.
    """

    def __init__(self, hidden_size=None, num_layers=None, bidirectional=None, dropout=0.0):
        super(MaskInference, self).__init__()
        hidden_size = Config.HIDDEN_SIZE if hidden_size is None else hidden_size
        num_layers = Config.NUM_LAYERS if num_layers is None else num_layers
        bidirectional = Config.BIDIRECTIONAL if bidirectional is None else bidirectional

        feature_size = (Config.N_FFT // 2 + 1) * 2
        self.lstm = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # nn.LSTM only applies dropout between layers.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        lstm_out_size = hidden_size * (2 if bidirectional else 1)
        self.mask = nn.Linear(lstm_out_size, feature_size)

    def forward(self, x):
        batch_size, channels, freq_bins, time_steps = x.size()

        # Check if the input dimensions are correct
        expected_input_size = (Config.N_FFT // 2 + 1) * 2
        actual_input_size = freq_bins * channels

        if actual_input_size != expected_input_size:
            raise ValueError(f"Expected input size {expected_input_size} but got {actual_input_size}")

        # Move time in front of (channels, freq) BEFORE flattening. A bare
        # reshape would only reinterpret memory and hand the LSTM rows that
        # mix several time frames of a few frequency bins.
        x = x.permute(0, 3, 1, 2).reshape(batch_size, time_steps, channels * freq_bins)
        Config.LOGGER.debug(f"Input shape after reshape: {x.shape}")

        x, _ = self.lstm(x)
        Config.LOGGER.debug(f"Output from LSTM shape: {x.shape}")

        # Project every time step back to channels * freq_bins features.
        mask = self.mask(x)
        Config.LOGGER.debug(f"Mask shape before final reshape: {mask.shape}")

        # Undo the flattening, then restore [batch, channels, freq_bins, time_steps].
        mask = mask.reshape(batch_size, time_steps, channels, freq_bins)
        mask = mask.permute(0, 2, 3, 1).contiguous()
        return mask


logger.info("Model definition complete.")


In [None]:
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
import numpy as np

# Authenticate with Weights & Biases.
# The API key is read from the WANDB_API_KEY environment variable (e.g. set it
# through the platform's secret manager); never paste the key into the
# notebook. When the variable is unset, wandb falls back to its own credential
# lookup / interactive prompt.
import os

wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Initialize Wandb project
wandb.init(project="audio_source_thesis")

# Declare global lists to store metrics
# One entry is appended per epoch across all Optuna trials.
train_losses = []
val_losses = []
# NOTE: despite the names, these three hold per-epoch VALIDATION metrics.
train_sdrs = []
train_sirs = []
train_sars = []

def calculate_sdr(reference_signal, estimated_signal, eps=1e-8):
    """Energy ratio (in dB) between the reference and the estimation error.

    Computes ``10 * log10(sum(ref^2) / sum((ref - est)^2))``.

    Args:
        reference_signal: Ground-truth tensor.
        estimated_signal: Estimated tensor, broadcastable to the reference.
        eps: Small constant added to both energies so a perfect estimate or a
            silent reference yields a finite value instead of inf/nan.

    Returns:
        The ratio in dB as a Python float.
    """
    noise = reference_signal - estimated_signal
    signal_energy = torch.sum(reference_signal ** 2)
    noise_energy = torch.sum(noise ** 2)
    sdr = 10 * torch.log10((signal_energy + eps) / (noise_energy + eps))
    return sdr.item()

def calculate_sir(reference_signal, estimated_signal, interference_signal, eps=1e-8):
    """Energy ratio (in dB) between the reference and ``interference_signal``.

    Computes ``10 * log10(sum(ref^2) / sum(interference^2))``.

    Args:
        reference_signal: Ground-truth tensor.
        estimated_signal: Accepted for signature compatibility; it does not
            enter the computation.
        interference_signal: Tensor whose energy forms the denominator.
        eps: Small constant added to both energies to keep the result finite
            when either energy is zero.

    Returns:
        The ratio in dB as a Python float.
    """
    signal_energy = torch.sum(reference_signal ** 2)
    interference_energy = torch.sum(interference_signal ** 2)
    sir = 10 * torch.log10((signal_energy + eps) / (interference_energy + eps))
    return sir.item()

def calculate_sar(reference_signal, estimated_signal, artifacts=None, eps=1e-8):
    """Energy ratio (in dB) between the reference and the artifact signal.

    Computes ``10 * log10(sum(ref^2) / sum(artifacts^2))``.

    Args:
        reference_signal: Ground-truth tensor.
        estimated_signal: Estimated tensor.
        artifacts: Artifact tensor. When omitted (``None``) it is taken to be
            ``estimated_signal - reference_signal``. A supplied value is now
            honoured instead of being silently overwritten.
        eps: Small constant added to both energies to keep the result finite
            when either energy is zero.

    Returns:
        The ratio in dB as a Python float.
    """
    if artifacts is None:
        artifacts = estimated_signal - reference_signal
    signal_energy = torch.sum(reference_signal ** 2)
    artifact_energy = torch.sum(artifacts ** 2)
    sar = 10 * torch.log10((signal_energy + eps) / (artifact_energy + eps))
    return sar.item()

def objective(trial):
    """Optuna objective: train one model and return its final validation loss.

    Samples hyperparameters, stores the model-related ones on ``Config``,
    trains for ``Config.NUM_EPOCHS`` epochs with Adam + MSE, and evaluates on
    the validation split after every epoch. Per-epoch losses and metrics are
    appended to the module-level history lists and logged to W&B. The
    validation loss is reported to Optuna each epoch so the trial can be
    pruned, and the weights with the lowest validation loss of this trial are
    written to 'best_model.pth' (overwritten by later trials).

    Args:
        trial: ``optuna.trial.Trial`` supplying the hyperparameter suggestions.

    Returns:
        Validation loss of the last completed epoch.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial early.
    """
    # Suggest hyperparameters. `suggest_float` replaces the deprecated
    # `suggest_loguniform` / `suggest_uniform`; parameter names and ranges are
    # unchanged.
    Config.LEARNING_RATE = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    Config.BATCH_SIZE = trial.suggest_categorical('batch_size', [16, 32, 64])
    Config.HIDDEN_SIZE = trial.suggest_categorical('hidden_size', [64, 128, 256])
    Config.NUM_LAYERS = trial.suggest_int('num_layers', 1, 4)
    Config.DROPOUT_RATE = trial.suggest_float('dropout_rate', 0.1, 0.5)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)

    # NOTE(review): the model is built without arguments, so hidden_size,
    # num_layers and dropout_rate only affect training if MaskInference reads
    # them from Config — verify.
    model = MaskInference().to(Config.DEVICE)

    # Log model architecture
    wandb.watch(model)

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    # Learning rate scheduler. The deprecated `verbose` flag is not used;
    # reductions are logged explicitly below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=Config.NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=Config.NUM_WORKERS)

    best_loss = float('inf')

    for epoch in range(Config.NUM_EPOCHS):
        # ---- training ----
        model.train()
        train_loss = 0.0

        for inputs, targets in train_loader:
            optimizer.zero_grad()

            inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)

            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        train_losses.append(train_loss)  # Store training loss
        wandb.log({"Training Loss": train_loss, "Epoch": epoch})

        # ---- validation ----
        val_loss = 0.0
        sdr_list, sir_list, sar_list = [], [], []
        model.eval()
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)

                output = model(inputs)
                loss = criterion(output, targets)
                val_loss += loss.item()

                # Calculate SDR, SIR, SAR (one value per batch)
                sdr_list.append(calculate_sdr(targets, output))
                sir_list.append(calculate_sir(targets, output, interference_signal=inputs))
                sar_list.append(calculate_sar(targets, output, artifacts=output - targets))

        val_loss /= len(val_loader)
        val_losses.append(val_loss)  # Store validation loss

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            Config.LOGGER.info(f'Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}')

        # Store validation metrics (the history lists keep their original names)
        mean_sdr = np.mean(sdr_list)
        mean_sir = np.mean(sir_list)
        mean_sar = np.mean(sar_list)
        train_sdrs.append(mean_sdr)
        train_sirs.append(mean_sir)
        train_sars.append(mean_sar)

        wandb.log({"Validation Loss": val_loss, "Validation SDR": mean_sdr, "Validation SIR": mean_sir, "Validation SAR": mean_sar, "Epoch": epoch})
        Config.LOGGER.info(f'Epoch {epoch + 1}/{Config.NUM_EPOCHS} - Training Loss: {train_loss:.4f} - Validation Loss: {val_loss:.4f} - SDR: {mean_sdr:.4f} - SIR: {mean_sir:.4f} - SAR: {mean_sar:.4f}')

        # Early stopping with Optuna
        trial.report(val_loss, epoch)
        if trial.should_prune():
            Config.LOGGER.info('Trial pruned at epoch {}'.format(epoch + 1))
            raise optuna.exceptions.TrialPruned()

        # Save the best model of this trial
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            Config.LOGGER.info(f'Best model saved at epoch {epoch + 1} with loss {best_loss:.4f}')

    # Return only the validation loss to Optuna
    return val_loss

# Running the optimization
# Minimises the value returned by `objective` over Config.NUM_TRIALS trials
# (the second positional argument of `optimize` is the trial count).
study = optuna.create_study(direction='minimize')
study.optimize(objective, Config.NUM_TRIALS)



In [None]:
import torch
import yaml
from torch.utils.data import DataLoader, ConcatDataset

# Load the best hyperparameters from Optuna
best_params = study.best_trial.params

# The objective leaves Config holding the values of the LAST trial. Restore the
# best trial's values so anything that reads Config (now or in later cells)
# sees the configuration that is actually being retrained.
Config.LEARNING_RATE = best_params['lr']
Config.BATCH_SIZE = best_params['batch_size']
Config.HIDDEN_SIZE = best_params['hidden_size']
Config.NUM_LAYERS = best_params['num_layers']

# Combine the training and validation datasets
full_dataset = ConcatDataset([train_dataset, val_dataset])
full_loader = DataLoader(full_dataset, batch_size=best_params['batch_size'], shuffle=True, num_workers=Config.NUM_WORKERS)

# Initialize the model with the best hyperparameters
model = MaskInference().to(Config.DEVICE)
optimizer = optim.Adam(model.parameters(), lr=best_params['lr'], weight_decay=best_params['weight_decay'])
criterion = nn.MSELoss()

# Retrain the model
for epoch in range(Config.NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for batch in full_loader:
        optimizer.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(full_loader)
    Config.LOGGER.info(f'Epoch {epoch + 1}/{Config.NUM_EPOCHS} - Loss: {avg_loss:.4f}')

# Calculate final metrics on the full dataset
# NOTE: these are metrics on the data the model was just trained on
# (train + validation); held-out numbers come from the test cell.
model.eval()
final_loss = avg_loss  # Final training loss
final_sdr_list, final_sir_list, final_sar_list = [], [], []

with torch.no_grad():
    for batch in full_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)
        outputs = model(inputs)

        sdr = calculate_sdr(targets, outputs)
        sir = calculate_sir(targets, outputs, interference_signal=inputs)
        sar = calculate_sar(targets, outputs, artifacts=outputs - targets)

        final_sdr_list.append(sdr)
        final_sir_list.append(sir)
        final_sar_list.append(sar)

# Compute mean metrics (mean over batches)
final_sdr = np.mean(final_sdr_list)
final_sir = np.mean(final_sir_list)
final_sar = np.mean(final_sar_list)

# Save the model
model_save_path = 'final_model.pth'
torch.save(model.state_dict(), model_save_path)
Config.LOGGER.info(f"Model saved to {model_save_path}")

# Save the best hyperparameters
hyperparams_save_path = 'best_hyperparameters.yaml'
with open(hyperparams_save_path, 'w') as f:
    yaml.dump(best_params, f)
Config.LOGGER.info(f"Hyperparameters saved to {hyperparams_save_path}")

# Save final metrics and loss
metrics_save_path = 'final_metrics.yaml'
final_metrics = {
    'final_loss': float(final_loss),  # Already a Python float
    'final_sdr': float(final_sdr),    # Convert numpy.float64 to Python float
    'final_sir': float(final_sir),    # Convert numpy.float64 to Python float
    'final_sar': float(final_sar)     # Convert numpy.float64 to Python float
}
with open(metrics_save_path, 'w') as f:
    yaml.dump(final_metrics, f)
Config.LOGGER.info(f"Final metrics saved to {metrics_save_path}")


In [None]:
import torch
import yaml
from torch.utils.data import DataLoader

# Load the test dataset
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=Config.NUM_WORKERS)

# Load the trained model
# map_location lets a checkpoint saved on GPU be restored on a CPU-only
# runtime (and vice versa) instead of failing inside torch.load.
model = MaskInference().to(Config.DEVICE)
model.load_state_dict(torch.load('final_model.pth', map_location=Config.DEVICE))
model.eval()

# Initialize lists to store the metrics
test_loss = 0.0
sdr_list, sir_list, sar_list = [], [], []

criterion = nn.MSELoss()

# Evaluate the model on the test set
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)

        # Generate predictions
        outputs = model(inputs)

        # Calculate loss
        loss = criterion(outputs, targets)
        test_loss += loss.item()

        # Calculate metrics
        sdr = calculate_sdr(targets, outputs)
        sir = calculate_sir(targets, outputs, interference_signal=inputs)
        sar = calculate_sar(targets, outputs, artifacts=outputs - targets)

        # Append metrics to lists
        sdr_list.append(sdr)
        sir_list.append(sir)
        sar_list.append(sar)

# Calculate average metrics (mean over test items, since batch_size is 1)
test_loss /= len(test_loader)
mean_sdr = np.mean(sdr_list)
mean_sir = np.mean(sir_list)
mean_sar = np.mean(sar_list)

# Log the results
Config.LOGGER.info(f"Test Loss: {test_loss:.4f}")
Config.LOGGER.info(f"Test SDR: {mean_sdr:.4f}")
Config.LOGGER.info(f"Test SIR: {mean_sir:.4f}")
Config.LOGGER.info(f"Test SAR: {mean_sar:.4f}")

# Save test metrics (cast to plain floats so YAML stays human-readable)
test_metrics_save_path = 'test_metrics.yaml'
test_metrics = {
    'test_loss': float(test_loss),
    'test_sdr': float(mean_sdr),
    'test_sir': float(mean_sir),
    'test_sar': float(mean_sar)
}
with open(test_metrics_save_path, 'w') as f:
    yaml.dump(test_metrics, f)
Config.LOGGER.info(f"Test metrics saved to {test_metrics_save_path}")


In [None]:
import torch
import numpy as np
import scipy.signal
import librosa

# Define post-processing functions

def refine_mask(mask, method='wiener', size=3):
    """
    Apply a post-processing refinement to the mask.

    The filter acts on the last two axes only (frequency x time). Leading
    axes such as batch and channel get a kernel extent of 1; a scalar kernel
    would also slide across them with zero padding, which e.g. makes the
    median filter return all zeros for a (1, 2, F, T) input.

    Args:
    - mask (torch.Tensor): The predicted mask.
    - method (str): The post-processing method. Options: 'median', 'wiener'.
    - size (int): Size of the (odd) filter window along frequency and time.

    Returns:
    - refined_mask (torch.Tensor): The refined mask, with the same shape,
      dtype and device as ``mask``.

    Raises:
    - ValueError: If ``method`` is neither 'median' nor 'wiener'.
    """
    mask_np = mask.detach().cpu().numpy()  # Convert to numpy for processing

    if mask_np.ndim > 2:
        kernel = [1] * (mask_np.ndim - 2) + [size, size]
    else:
        kernel = size

    if method == 'median':
        refined_mask = scipy.signal.medfilt(mask_np, kernel_size=kernel)
    elif method == 'wiener':
        refined_mask = scipy.signal.wiener(mask_np, mysize=kernel)
    else:
        raise ValueError("Unsupported method. Choose 'median' or 'wiener'.")

    # SciPy may promote to float64; cast back so downstream tensor math keeps
    # the model's dtype.
    return torch.as_tensor(refined_mask, dtype=mask.dtype, device=mask.device)

def apply_dynamic_range_compression(signal, threshold=0.5, ratio=4.0):
    """
    Scale down the part of every value that exceeds ``threshold``.

    Values greater than ``threshold`` become
    ``threshold + (value - threshold) / ratio``; all other values, including
    negative ones of any magnitude, pass through unchanged.

    Args:
    - signal (torch.Tensor): The input tensor.
    - threshold (float): Level above which the excess is reduced.
    - ratio (float): Divisor applied to the excess over the threshold.

    Returns:
    - compressed_signal (torch.Tensor): New tensor on the same device as
      ``signal``.
    """
    samples = signal.cpu().numpy()

    # Compute the compressed candidate everywhere, then keep it only where the
    # sample is above the threshold.
    above_threshold = samples > threshold
    squeezed = threshold + (samples - threshold) / ratio
    result = np.where(above_threshold, squeezed, samples)

    return torch.tensor(result, device=signal.device)

# Load the trained model
# map_location lets a checkpoint saved on GPU be restored on a CPU-only
# runtime (and vice versa) instead of failing inside torch.load.
model = MaskInference().to(Config.DEVICE)
model.load_state_dict(torch.load('final_model.pth', map_location=Config.DEVICE))
model.eval()

# Define a DataLoader for the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=Config.NUM_WORKERS)

# Lists to store post-processing results
test_sdrs = []
test_sirs = []
test_sars = []

# Iterate over the test set and apply post-processing
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(Config.DEVICE), targets.to(Config.DEVICE)

        # Model inference
        output = model(inputs)

        # Apply post-processing techniques: smooth the network output, then
        # compress values above the default threshold.
        refined_mask = refine_mask(output, method='wiener', size=3)
        output_refined = apply_dynamic_range_compression(refined_mask)

        # Calculate SDR, SIR, SAR
        sdr = calculate_sdr(targets, output_refined)
        sir = calculate_sir(targets, output_refined, interference_signal=inputs)
        sar = calculate_sar(targets, output_refined, artifacts=output_refined - targets)

        # Store results
        test_sdrs.append(sdr)
        test_sirs.append(sir)
        test_sars.append(sar)

        Config.LOGGER.info(f"Processed batch - SDR: {sdr:.4f}, SIR: {sir:.4f}, SAR: {sar:.4f}")

# Compute mean values for test metrics
mean_test_sdr = np.mean(test_sdrs)
mean_test_sir = np.mean(test_sirs)
mean_test_sar = np.mean(test_sars)

# Log final test metrics
Config.LOGGER.info(f"Final Test Metrics - SDR: {mean_test_sdr:.4f}, SIR: {mean_test_sir:.4f}, SAR: {mean_test_sar:.4f}")

# Optionally, log these results to Weights & Biases
wandb.log({
    "Test SDR": mean_test_sdr,
    "Test SIR": mean_test_sir,
    "Test SAR": mean_test_sar
})

print("Post-processing and evaluation complete.")


In [None]:
import librosa
import torch
import numpy as np
import scipy.ndimage
import pywt
from scipy.io.wavfile import write
from scipy import signal

# High-Pass Filter using scipy
def apply_high_pass_filter(audio, sr, cutoff=100):
    """Apply a 2nd-order Butterworth high-pass filter along the last axis.

    Args:
        audio: Array of samples.
        sr: Sample rate in Hz.
        cutoff: Cut-off frequency in Hz.

    Returns:
        The filtered array (single forward pass, second-order sections).
    """
    sections = signal.butter(2, cutoff, btype='highpass', output='sos', fs=sr)
    return signal.sosfilt(sections, audio)

# Dynamic Range Compression
def apply_compression(audio, threshold=-20, ratio=4):
    """Static (memoryless) dynamic range compression.

    Samples whose magnitude exceeds the threshold have the excess level (in
    dB) divided by ``ratio``; quieter samples pass through unchanged. The sign
    of every sample is preserved. No attack/release smoothing is applied.

    The previous implementation passed ``threshold`` as the coefficient of a
    pre-emphasis filter, which is not a compressor and ignored ``ratio``.

    Args:
        audio: Array of samples, full scale = 1.0; any shape.
        threshold: Compression threshold in dB relative to full scale.
        ratio: Compression ratio (> 0); 4 means 4 dB in -> 1 dB out above the
            threshold.

    Returns:
        Array of the same shape (and floating dtype) as ``audio``.

    Raises:
        ValueError: If ``ratio`` is not positive.
    """
    if ratio <= 0:
        raise ValueError("ratio must be positive")

    samples = np.asarray(audio)
    threshold_amp = 10.0 ** (threshold / 20.0)
    magnitude = np.abs(samples)

    # out_dB = thr_dB + (in_dB - thr_dB) / ratio, written in linear amplitude.
    # np.maximum keeps the power well-defined for the samples that are below
    # the threshold (they are discarded by np.where anyway).
    compressed = threshold_amp * (np.maximum(magnitude, threshold_amp) / threshold_amp) ** (1.0 / ratio)
    result = np.where(magnitude > threshold_amp, np.sign(samples) * compressed, samples)

    if np.issubdtype(samples.dtype, np.floating):
        result = result.astype(samples.dtype, copy=False)
    return result

# Noise Reduction using Spectral Gating
def reduce_noise(audio, sr, noise_reduction_factor=1.5, n_fft=2048, hop_length=512, n_std_thresh=1.5):
    """
    Reduces noise from audio using spectral gating.

    Every time-frequency bin whose magnitude is below
    ``noise_reduction_factor`` times the mean magnitude of its frequency bin
    (averaged over all frames) is zeroed; the original phase is kept.

    Parameters:
    - audio: np.array, the input audio signal (mono, or channel-first multi-channel).
    - sr: int, the sample rate of the audio (accepted for API symmetry; not used).
    - noise_reduction_factor: float, factor to determine noise threshold.
    - n_fft: int, FFT window size.
    - hop_length: int, hop length for STFT.
    - n_std_thresh: float, currently unused; kept for backward compatibility.

    Returns:
    - reduced_audio: np.array, the denoised audio signal, same length as the input.
    """
    # Step 1: Compute the STFT of the audio
    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(stft), np.angle(stft)

    # Step 2: Per-frequency noise estimate = mean magnitude over the frame
    # (last) axis. Using axis=-1 rather than axis=1 keeps this correct for
    # multi-channel input, where axis 1 is the frequency axis.
    noise_estimate = np.mean(magnitude, axis=-1, keepdims=True)

    # Step 3: Create a noise threshold
    noise_thresh = noise_estimate * noise_reduction_factor

    # Step 4: Apply spectral gating
    mask = magnitude >= noise_thresh
    gated_magnitude = magnitude * mask

    # Step 5: Reconstruct the denoised audio using inverse STFT; `length`
    # pins the output to the input's sample count.
    stft_denoised = gated_magnitude * np.exp(1j * phase)
    reduced_audio = librosa.istft(stft_denoised, hop_length=hop_length, length=np.shape(audio)[-1])

    return reduced_audio

# Spectrogram Augmentation (time masking)
def time_masking(spec, T=20):
    """Zero a random run of ``T`` consecutive frames (time masking).

    The mask is applied along the last axis, which is the time axis for both
    ``(freq, frames)`` and ``(channels, freq, frames)`` layouts.

    Args:
        spec: Spectrogram array; it is NOT modified.
        T: Number of consecutive frames to zero.

    Returns:
        A masked copy of ``spec``. When ``T`` is not positive or the
        spectrogram has at most ``T`` frames, an unmasked copy is returned
        instead of raising.
    """
    masked = np.array(spec, copy=True)
    n_frames = masked.shape[-1]
    if T <= 0 or n_frames <= T:
        return masked

    # Upper bound is exclusive, so +1 lets the mask reach the final frame.
    start = np.random.randint(0, n_frames - T + 1)
    masked[..., start:start + T] = 0
    return masked

# Smoothing Mask
def smooth_mask(mask, kernel_size=3):
    """Median-smooth a mask over its last two axes (frequency x time).

    Leading axes (e.g. batch, channel) get a window extent of 1 so values are
    never mixed across them.

    Args:
        mask: torch.Tensor to smooth.
        kernel_size: Window extent along each of the last two axes.

    Returns:
        torch.Tensor with the same shape, dtype and device as ``mask``.
    """
    mask_np = mask.detach().cpu().numpy()
    if mask_np.ndim > 2:
        window = (1,) * (mask_np.ndim - 2) + (kernel_size, kernel_size)
    else:
        window = kernel_size
    smoothed_mask = scipy.ndimage.median_filter(mask_np, size=window)
    return torch.as_tensor(smoothed_mask, dtype=mask.dtype, device=mask.device)

# Wavelet Denoising
def wavelet_denoise(audio):
    """
    Apply wavelet denoising to the audio signal.

    A 2-level 'db1' decomposition is computed; the detail coefficients are
    soft-thresholded with a noise estimate derived from the finest detail
    band (median absolute value / 0.6745), and the signal is reconstructed.
    The approximation band is left untouched, since shrinking it would
    attenuate the low-frequency content of the signal itself rather than
    noise.

    Parameters:
    - audio (np.ndarray): The input audio signal.

    Returns:
    - denoised_audio (np.ndarray): The denoised audio signal, trimmed to the
      input length (reconstruction can add one sample for odd lengths).
    """
    coeffs = pywt.wavedec(audio, 'db1', level=2)
    threshold = np.median(np.abs(coeffs[-1])) / 0.6745
    denoised_coeffs = [coeffs[0]] + [
        pywt.threshold(c, value=threshold, mode='soft') for c in coeffs[1:]
    ]
    denoised_audio = pywt.waverec(denoised_coeffs, 'db1')
    return denoised_audio[..., :np.shape(audio)[-1]]

# Dynamic Range Expansion
def apply_dynamic_range_expansion(audio, expansion_ratio=2):
    """Expand dynamic range by raising sample magnitudes to ``expansion_ratio``.

    The sign of each sample is preserved. A plain ``audio ** 2`` would map
    every sample to a non-negative value, i.e. rectify the waveform.

    Args:
        audio: Array of samples (full scale = 1.0).
        expansion_ratio: Exponent applied to the magnitudes.

    Returns:
        Array of the same shape as ``audio``.
    """
    samples = np.asarray(audio)
    expanded_audio = np.sign(samples) * np.abs(samples) ** expansion_ratio
    return expanded_audio

# Load MP3 function
def load_mp3(file_path, target_sample_rate=44100):
    """Read an audio file with librosa, resampled to ``target_sample_rate``.

    Channels are kept (``mono=False``).

    Returns:
        Tuple ``(samples, sample_rate)`` as produced by ``librosa.load``.
    """
    loaded = librosa.load(file_path, mono=False, sr=target_sample_rate)
    return loaded[0], loaded[1]

# Preprocessing Function
def preprocess_audio(y, sr, augment=False):
    """Preprocess the audio data to prepare it for the model.

    Steps: high-pass filter -> compression -> spectral-gating noise reduction
    -> mono down-mix -> STFT -> standardised magnitude. The magnitude is then
    stacked twice along the frequency axis so that the flattened feature size
    matches what the model expects.

    Args:
        y: Audio samples (mono, or channel-first multi-channel).
        sr: Sample rate in Hz.
        augment: When True, apply random time masking to the magnitude. Off by
            default: masking is a training-time augmentation, and applying it
            unconditionally would blank a random stretch of the input (and
            make the output non-deterministic) at inference.

    Returns:
        Tuple ``(magnitude_tensor, phase)``: a float32 tensor of shape
        ``(1, 1, 2 * freq_bins, frames)`` and the STFT phase array of shape
        ``(freq_bins, frames)``.
    """
    y = apply_high_pass_filter(y, sr)
    y = apply_compression(y)
    y = reduce_noise(y, sr)

    if y.ndim == 2:
        y = librosa.to_mono(y)

    stft = librosa.stft(y, n_fft=2048, hop_length=512)
    magnitude = np.abs(stft)
    phase = np.angle(stft)

    # NOTE(review): the training dataset converts magnitudes to dB before
    # standardising; here linear magnitudes are standardised — confirm which
    # feature scale the deployed model should receive.
    magnitude_normalized = (magnitude - magnitude.mean()) / (magnitude.std() + 1e-6)
    if augment:
        magnitude_normalized = time_masking(magnitude_normalized)

    # Double the frequency bins to match the model's expected input size
    magnitude_doubled = np.concatenate([magnitude_normalized, magnitude_normalized], axis=0)

    magnitude_tensor = torch.tensor(magnitude_doubled, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    return magnitude_tensor, phase

# Post-Processing and Reconstruction
def reconstruct_audio(separated_magnitude, phase):
    """Reconstruct the audio waveform from the modified spectrogram.

    `preprocess_audio` stacks the magnitude twice along the frequency axis,
    so the spectrogram arriving here has twice as many bins as `phase`.
    Inverting that stacked array directly makes librosa infer
    n_fft = 2 * (bins - 1) from the doubled bin count, which does not match
    the n_fft used for analysis. The stacked halves are therefore folded back
    to the bin count of `phase` before the inverse STFT.

    NOTE(review): the two halves are averaged, on the assumption that both
    describe the same mono signal — confirm against how the model was trained.

    Parameters:
    - separated_magnitude (np.ndarray): Masked magnitude, shape
      (..., bins, frames) or (..., 2 * bins, frames).
    - phase (np.ndarray): Phase of the original STFT, shape (bins, frames).

    Returns:
    - y_reconstructed (np.ndarray): Waveform after wavelet denoising and
      dynamic range expansion; leading axes of the input are preserved.
    """
    separated_magnitude = np.asarray(separated_magnitude)
    n_bins = phase.shape[-2]

    if separated_magnitude.shape[-2] == 2 * n_bins:
        # Undo the frequency-axis stacking so the bin count matches `phase`.
        lower_half = separated_magnitude[..., :n_bins, :]
        upper_half = separated_magnitude[..., n_bins:, :]
        separated_magnitude = 0.5 * (lower_half + upper_half)

    complex_stft = separated_magnitude * np.exp(1j * phase)
    y_reconstructed = librosa.istft(complex_stft, hop_length=512)

    # Apply wavelet denoising directly to the numpy array
    y_reconstructed = wavelet_denoise(y_reconstructed)
    y_reconstructed = apply_dynamic_range_expansion(y_reconstructed)

    return y_reconstructed

# Save the audio file
def save_audio(file_path, audio_data, sample_rate):
    """Save the processed audio to a file as 16-bit PCM.

    Samples are treated as floats in [-1, 1]. Values outside that range are
    clipped and non-finite values are replaced before scaling, because
    casting an out-of-range float to int16 wraps around instead of saturating.

    Parameters:
    - file_path: Destination path (or open file-like object) handed to `write`.
    - audio_data (np.ndarray): Floating-point samples.
    - sample_rate (int): Sample rate written to the file header.
    """
    samples = np.nan_to_num(np.asarray(audio_data))
    samples = np.clip(samples, -1.0, 1.0)
    write(file_path, sample_rate, (samples * 32767).astype(np.int16))

# Remove vocals function (this is the core function you need to define)
def remove_vocals(model, magnitude_tensor):
    """
    Use the trained model to remove vocals from the input spectrogram.

    The forward pass runs without gradient tracking and with the model in
    evaluation mode, so layers that behave differently during training (such
    as dropout or batch normalization) produce deterministic output. Each
    submodule's previous training flag is restored afterwards.

    Parameters:
    - model: The trained PyTorch model (any callable is accepted; mode
      handling only applies to torch.nn.Module instances).
    - magnitude_tensor: The input magnitude spectrogram as a tensor.

    Returns:
    - mask: The vocal mask predicted by the model.
    """
    # Record per-module flags so the caller's train/eval state is restored exactly.
    saved_modes = []
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()

    try:
        with torch.no_grad():
            mask = model(magnitude_tensor)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return mask

# Full Processing Pipeline
def process_audio(file_path, output_path, model):
    """Run the vocal-removal pipeline on one audio file and save the result.

    Loads the file, builds the model input, predicts and smooths the mask,
    applies (1 - mask) to the spectrogram, reconstructs the waveform and
    writes it to `output_path`.

    NOTE(review): the spectrogram that gets masked is the normalized one
    returned by preprocess_audio; nothing here restores the original scale
    before reconstruction — confirm that is intended.

    Parameters:
    - file_path: Path of the input audio file.
    - output_path: Path of the WAV file to write.
    - model: The trained PyTorch model that predicts the mask.
    """
    y, sr = load_mp3(file_path)

    # Preprocess the audio
    magnitude_tensor, phase = preprocess_audio(y, sr)

    # Ensure the magnitude tensor is on the same device as the model
    magnitude_tensor = magnitude_tensor.to(Config.DEVICE)

    # Remove vocals using the trained model
    mask = remove_vocals(model, magnitude_tensor)

    # Smooth the mask
    smoothed_mask = smooth_mask(mask)

    # Apply the mask to the magnitude spectrogram. Both operands are squeezed:
    # multiplying the squeezed magnitude by an un-squeezed mask would broadcast
    # the singleton axes back in and carry them through to the saved waveform.
    magnitude = magnitude_tensor.squeeze().cpu().numpy()
    keep_fraction = 1 - smoothed_mask.squeeze().cpu().numpy()
    separated_magnitude = magnitude * keep_fraction

    # Reconstruct the waveform from the modified magnitude and original phase
    reconstructed_waveform = reconstruct_audio(separated_magnitude, phase)

    # Save the result as a WAV file
    save_audio(output_path, reconstructed_waveform, sr)
    print(f"Vocal removal complete. The output audio is saved as '{output_path}'.")

# Example Usage
# Runs the full pipeline once on a hard-coded input file and writes the
# result to a hard-coded output path. `model` is not defined in this cell;
# it must already exist in the notebook session when this runs.
file_path = '/kaggle/input/example-song/Corey Taylor - Snuff (Acoustic).mp3'
output_path = '/kaggle/working/example_song_no_vocals.wav'
process_audio(file_path, output_path, model)


In [None]:
!pip install onnxruntime shap


In [None]:
# Assuming you have a test dataset ready (`test_dataset` comes from an earlier cell)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Get a single batch (mixture and target); one example is enough for the explanation.
# Fail loudly on an empty loader: a for/break loop would simply not run, leaving
# `mixture` and `target` undefined or holding stale values from an earlier run.
batch = next(iter(test_loader), None)
if batch is None:
    raise ValueError("test_loader is empty; cannot fetch an example batch")

mixture, target = batch
mixture = mixture.to(Config.DEVICE)
target = target.to(Config.DEVICE)

# Now you have `mixture` and `target` defined


In [None]:
import torch
import matplotlib.pyplot as plt

# Define the function to compute saliency maps
def compute_saliency_maps(input_data, target_class):
    """Compute the absolute gradient of the loss with respect to the input.

    Uses the module-level `model` and `criterion`.

    Parameters:
    - input_data (torch.Tensor): Model input. It is not modified; the gradient
      is taken on a detached copy, so repeated calls do not accumulate into
      the caller's tensor.
    - target_class (torch.Tensor): Target passed to `criterion` with the
      model output.

    Returns:
    - saliency (np.ndarray): |d loss / d input|, same shape as `input_data`.
    """
    # Ensure the model is in training mode to allow backward pass
    # NOTE(review): kept from the original. Train mode also activates layers
    # such as dropout/batch-norm during this forward pass — confirm train mode
    # is really required for this model.
    model.train()

    try:
        # Detached leaf copy: gradients are tracked on it without touching the caller's tensor.
        saliency_input = input_data.detach().clone().requires_grad_(True)

        # Forward pass
        output = model(saliency_input)

        # Calculate the loss
        loss = criterion(output, target_class)

        # Differentiate with respect to the input only, so the parameters'
        # .grad buffers are left untouched.
        (input_grad,) = torch.autograd.grad(loss, saliency_input)

        # Get the absolute value of the gradients
        saliency = input_grad.abs().detach().cpu().numpy()
    finally:
        # Switch the model back to evaluation mode, even if the pass failed
        model.eval()

    return saliency


def plot_saliency_map(saliency_map, title="Saliency Map", save_path=None):
    """Render a saliency map as a heat map, and either save or show it.

    Parameters:
    - saliency_map (np.ndarray): Saliency values; expected 4-D
      (batch_size, channels, freq_bins, time_steps). Singleton axes are
      dropped and, if more than two axes remain, the first entry along each
      leading axis is used.
    - title (str): Figure title.
    - save_path (str | None): When given, the figure is written there and
      closed; otherwise it is shown.

    Raises:
    - ValueError: If fewer than two axes remain after removing singletons.
    """
    # Reduce to 2D: (freq_bins, time_steps)
    saliency_map = saliency_map.squeeze()

    # Select the first batch item / channel until only two axes are left.
    while saliency_map.ndim > 2:
        saliency_map = saliency_map[0]

    if saliency_map.ndim < 2:
        raise ValueError(
            f"saliency_map must have at least 2 non-singleton dimensions, got shape {saliency_map.shape}"
        )

    fig = plt.figure(figsize=(8, 6))
    plt.imshow(saliency_map, cmap='inferno', aspect='auto')
    plt.colorbar(label='Saliency Intensity')
    plt.title(title, fontsize=18, fontweight='bold')
    plt.xlabel('Time Frames', fontsize=14)
    plt.ylabel('Frequency Bins', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        # Release the figure once it is on disk so repeated calls do not pile up open figures.
        plt.close(fig)
    else:
        plt.show()


       
# Assumes `mixture` and `target` are tensors defined by an earlier cell.
# Compute the saliency of the loss with respect to the example mixture.
# Both tensors are moved to Config.DEVICE before the call.
saliency_map = compute_saliency_maps(mixture.to(Config.DEVICE), target.to(Config.DEVICE))

# Plot the saliency map (no save_path given, so the figure is shown rather than saved)
plot_saliency_map(saliency_map, title="Saliency Map")



In [None]:
import matplotlib.pyplot as plt
import numpy as np
import yaml

# Read the final (training) metrics and the test metrics from their YAML files
with open('final_metrics.yaml', 'r') as f:
    final_metrics = yaml.safe_load(f)
with open('test_metrics.yaml', 'r') as f:
    test_metrics = yaml.safe_load(f)

# Pull each final (training) metric out of its mapping; a missing key falls back to 0
final_sdr, final_sir, final_sar, final_loss = [
    final_metrics.get(key, 0)
    for key in ('final_sdr', 'final_sir', 'final_sar', 'final_loss')
]

# Same for the test metrics
test_sdr, test_sir, test_sar, test_loss = [
    test_metrics.get(key, 0)
    for key in ('test_sdr', 'test_sir', 'test_sar', 'test_loss')
]

# Values and labels in matching order for the grouped bar chart
metrics_labels = ["SDR", "SIR", "SAR"]
final_val_metrics = [final_sdr, final_sir, final_sar]
final_test_metrics = [test_sdr, test_sir, test_sar]

# Plot comparison of SDR, SIR, and SAR between final (training) and test
def plot_metric_comparison(val_metrics, test_metrics, labels):
    """Draw a grouped bar chart comparing two sets of metrics.

    Parameters:
    - val_metrics (sequence of float): Values for the 'Final (Training)' bars.
    - test_metrics (sequence of float): Values for the 'Test' bars.
    - labels (sequence of str): One x-axis label per metric, in the same order.
    """
    x = np.arange(len(labels))  # Label locations
    width = 0.35  # Width of the bars

    fig, ax = plt.subplots(figsize=(10, 6))

    rects1 = ax.bar(x - width/2, val_metrics, width, label='Final (Training)', color='dodgerblue')
    rects2 = ax.bar(x + width/2, test_metrics, width, label='Test', color='orange')

    # Add some text for labels, title, and custom x-axis tick labels, etc.
    ax.set_ylabel('Metric (dB)', fontsize=14)
    ax.set_title('Comparison of Final (Training) vs Test Metrics', fontsize=16, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=12)
    ax.legend(fontsize=12)

    # Add value labels at the outer end of each bar
    def add_labels(rects):
        for rect in rects:
            height = rect.get_height()
            # Negative bars hang below zero; put their label underneath the
            # bar so it is not drawn on top of the bar itself.
            is_negative = height < 0
            ax.annotate(f'{height:.2f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, -3 if is_negative else 3),  # 3 points away from the bar end
                        textcoords="offset points",
                        ha='center',
                        va='top' if is_negative else 'bottom',
                        fontsize=12)

    add_labels(rects1)
    add_labels(rects2)

    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.show()

# Call the plotting function
# The three lists are aligned by position: SDR, SIR, SAR.
plot_metric_comparison(final_val_metrics, final_test_metrics, metrics_labels)

# Plot loss comparison
def plot_loss_comparison(final_loss, test_loss):
    """Show a two-bar chart of the final (training) loss next to the test loss.

    Parameters:
    - final_loss (float): Height of the 'Final (Training) Loss' bar.
    - test_loss (float): Height of the 'Test Loss' bar.
    """
    bar_names = ['Final (Training) Loss', 'Test Loss']
    bar_values = [final_loss, test_loss]

    plt.figure(figsize=(8, 5))
    drawn_bars = plt.bar(bar_names, bar_values, color=['dodgerblue', 'orange'], edgecolor='black')

    # Print each bar's value 0.02 data units above the bar height, centered on the bar
    for patch in drawn_bars:
        bar_height = patch.get_height()
        center_x = patch.get_x() + patch.get_width()/2
        plt.text(center_x, bar_height + 0.02, f'{bar_height:.2f}', ha='center', fontsize=12)

    plt.title('Final (Training) vs Test Loss', fontsize=16, fontweight='bold')
    plt.ylabel('Loss', fontsize=14)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.show()

# Call the loss comparison plotting function
# Uses the loss values read from the two YAML files earlier in this cell.
plot_loss_comparison(final_loss, test_loss)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visualization Functions
def plot_loss_curves(train_losses, val_losses):
    """Plot training and validation loss against the epoch number.

    Each curve gets its own epoch axis (1..len), so the two sequences may
    have different lengths.

    Parameters:
    - train_losses (sequence of float): Training loss per epoch.
    - val_losses (sequence of float): Validation loss per epoch.
    """
    train_epochs = np.arange(1, len(train_losses) + 1)
    val_epochs = np.arange(1, len(val_losses) + 1)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    ax.plot(val_epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax.set_title('Training and Validation Loss Over Epochs', fontsize=16, fontweight='bold')
    ax.set_xlabel('Epochs', fontsize=14)
    ax.set_ylabel('Loss', fontsize=14)
    ax.legend(fontsize=12)

    # Force the x-axis to display only whole numbers. Up to 20 epochs every
    # epoch gets a tick; beyond that the ticks are thinned so they stay legible.
    n_epochs = max(len(train_losses), len(val_losses))
    tick_step = max(1, int(np.ceil(n_epochs / 20)))
    ax.set_xticks(np.arange(1, n_epochs + 1, tick_step))

    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.show()

# Example usage
# `train_losses` and `val_losses` are not defined in this cell; they must
# already exist in the notebook session.
plot_loss_curves(train_losses, val_losses)


In [None]:
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Function to convert the PyTorch model to ONNX format
def convert_to_onnx_simple(model, input_tensor, output_path):
    """
    Export a PyTorch model to an ONNX file.

    The graph input is named 'input' and the output 'output'; axis 0 of both
    is declared dynamic under the name 'batch_size'.

    Parameters:
    - model: The trained PyTorch model.
    - input_tensor: A sample input tensor used for the export.
    - output_path: The file path to save the ONNX model.

    Returns:
    - None
    """
    batch_axis = {0: 'batch_size'}
    export_options = dict(
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': dict(batch_axis), 'output': dict(batch_axis)},
    )
    torch.onnx.export(model, input_tensor, output_path, **export_options)
    print(f"Model has been converted to ONNX and saved to {output_path}")

# Function to quantize the ONNX model
def quantize_onnx_model_simple(onnx_model_path, quantized_model_path):
    """
    Run dynamic quantization on an ONNX model file, using QInt8 weights.

    Parameters:
    - onnx_model_path: Path to the original ONNX model.
    - quantized_model_path: Path to save the quantized ONNX model.

    Returns:
    - None
    """
    int8_weights = QuantType.QInt8
    quantize_dynamic(
        onnx_model_path,
        quantized_model_path,
        weight_type=int8_weights,
    )
    print(f"Quantized model saved to {quantized_model_path}")

# Example usage of the full process
def full_process_simple(model, input_size, device,
                        onnx_model_path='model.onnx',
                        quantized_model_path='model_quantized.onnx',
                        batch_size=1, channels=1):
    """
    Run the full process including ONNX conversion and quantization.

    Parameters:
    - model: The trained PyTorch model. It is moved to `device` in place.
    - input_size: A tuple representing the size of the input tensor
      (everything after the batch and channel axes).
    - device: The device to which the tensors should be moved (e.g., 'cpu' or 'cuda').
    - onnx_model_path: Where the exported ONNX model is written.
    - quantized_model_path: Where the quantized ONNX model is written.
    - batch_size: Batch size of the dummy input used for the export.
    - channels: Channel count of the dummy input used for the export.

    Returns:
    - None
    """
    # Move the model to the correct device
    model.to(device)

    # Create a dummy input tensor with the correct dimensions for tracing;
    # only its shape and dtype matter, the random values are irrelevant.
    dummy_input = torch.randn(batch_size, channels, *input_size).to(device)

    # Convert the model to ONNX format
    convert_to_onnx_simple(model, dummy_input, onnx_model_path)

    # Quantize the ONNX model
    quantize_onnx_model_simple(onnx_model_path, quantized_model_path)

# Assuming you have a model and input size ready
# (`model` is not defined in this cell; it must already exist in the session.)
input_size = (2050, 23482)  # Frequency bins x Time steps
model.eval()  # Set model to evaluation mode

# Specify the device (e.g., 'cuda' for GPU, 'cpu' for CPU)
# Picks CUDA when torch reports it as available, otherwise falls back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run the full process
# Writes the exported and quantized models to the paths set inside full_process_simple.
full_process_simple(model, input_size, device)
