In [6]:
import os
import torch
import torchvision.transforms as transforms
import torchaudio.transforms as T
import torchvision.io as io
import torchaudio
import numpy as np
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

Mounted at /content/drive


In [3]:

# Define augmentation functions
def time_stretch(audio, factor):
    return torchaudio.transforms.TimeStretch(n_stps=factor)(audio)

def pitch_shift(audio, shift):
    return torchaudio.transforms.PitchShift(sample_rate=audio.size(1), n_steps=shift)(audio)

def noise_injection(audio, noise_level):
    noise = torch.randn_like(audio) * noise_level
    return audio + noise

def frequency_masking(spectrogram, num_masks=2, mask_factor=27):
    masked_spectrogram = spectrogram.clone()
    for _ in range(num_masks):
        f = torch.randint(low=0, high=mask_factor, size=(1,))
        f_max = min(f + mask_factor, spectrogram.size(1))
        masked_spectrogram[:, f:f_max] = 0
    return masked_spectrogram

def dynamic_range_compression(audio, factor):
    return torchaudio.transforms.Vol(factor)(audio)

def time_warp(audio, warp_factor):
    return torchaudio.transforms.TimeWarp(sample_rate=audio.size(1), warp_param=warp_factor)(audio)


In [4]:

# Define augmentation parameters
augmentation_params = {
    "time_stretch_factor": 1.1,
    "pitch_shift_amount": 3,
    "noise_level": 0.1,
    "time_warp_factor": 0.2,
    "freq_masking": 2,
    "freq_mask_width": 15
}

# Define directory paths
data_dir = '/content/drive/My Drive/DLproject'
output_dir = '/content/drive/My Drive/DLproject/train_augmented'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Load the class distribution
class_distribution = {
    "dog_barking": 640,
    "car_horn": 344,
    "Fart": 291,
    "Guitar": 548,
    "drilling": 560,
    "Gunshot_and_gunfire": 448,
    "Hi-hat": 171,
    "Knock": 168,
    "Splash_and_splatter": 174,
    "Snare_drum": 449,
    "Shatter": 212,
    "Laughter": 295,
    "siren": 560
}

# Define threshold for underrepresented classes
threshold = max(class_distribution.values())  # You can adjust this threshold based on your dataset


In [8]:
# Iterate over the class distribution
for class_name, num_samples in class_distribution.items():
    if num_samples < threshold:
        # Calculate augmentation factor needed for this class
        augmentation_factor = int(np.ceil(threshold / num_samples))

        # Load mel spectrograms for the underrepresented class
        class_dir = os.path.join(data_dir, 'train', class_name)  # Adjust the path here
        for mel_file in os.listdir(class_dir):
            mel_path = os.path.join(class_dir, mel_file)
            spectrogram = io.read_image(mel_path, mode=io.ImageReadMode.GRAY)

            # Apply augmentation
            augmented_spectogram = spectrogram.clone()
            for i in range(augmentation_factor):

                # augmented_spectrogram = time_stretch(augmented_spectrogram, factor=augmentation_params["time_stretch_factor"])
                # augmented_spectrogram = pitch_shift(augmented_spectrogram, shift=augmentation_params["pitch_shift_amount"])
                # augmented_spectrogram = noise_injection(augmented_spectrogram, noise_level=augmentation_params["noise_level"])
                # augmented_spectrogram = frequency_masking(augmented_spectrogram, num_masks=augmentation_params["freq_masking"], mask_factor=augmentation_params["freq_mask_width"])
                # augmented_spectrogram = time_warp(augmented_spectrogram, warp_factor=augmentation_params["time_warp_factor"])

                time_masking = T.TimeMasking(time_mask_param = 80)
                freq_masking = T.FrequencyMasking(freq_mask_param=80)

                augmented_spectogram = time_masking(augmented_spectogram)
                augmented_spectogram = freq_masking(augmented_spectogram)

                output_mel_path = os.path.join(output_dir, class_name, f"augmented_{mel_file}_{i}.png")
                io.write_image(output_mel_path, augmented_spectrogram)

AttributeError: module 'torchvision.io' has no attribute 'write_image'