This is a PyTorch implementation of MetricGAN.
The model architecture and hyperparameters are taken from the Keras implementation in [1].

[1] https://github.com/JasonSWFu/MetricGAN

In [None]:
# pip installs go here
! pip install pystoi
! pip install transformers
! pip install pydub

Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1


In [None]:
# imports go here
from google.colab import drive
import os
import librosa
import soundfile as sf
import numpy as np
import math
import scipy
import time
import datetime
import re
import pickle
import subprocess

import torch
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm
from torch.utils.data import TensorDataset, DataLoader, Dataset, RandomSampler
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pystoi import stoi

from pydub import AudioSegment

In [None]:
# Mount Google Drive so the dataset, checkpoint and output paths under /content/gdrive are available.
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# ---------------- Global configuration ----------------

# Paths
DATASET_DIR = "/content/gdrive/MyDrive/MS-SNSD-dataset-30"   # root containing <split>/clean and <split>/noisy
CHECKPT_DIR = "/content/gdrive/MyDrive/se-checkpoints/"      # checkpt-*-<N>.pt and npairs_<N>.pkl live here

# Signal / mask
SCALE_FACTOR = 10        # noisy audio is multiplied by this before the STFT and divided back afterwards
MASK_MIN_VALUE = 0.05    # floor applied to the generator's predicted mask
TARGET = 1               # 0-1 range of clean-ness

# Training schedule
BATCH_SIZE = 1
NUM_GAN_EPOCHS = 10              # Original paper uses 200
NUM_DISCRIMINATOR_EPOCHS = 2     # discriminator epochs inside each GAN epoch (was 15)
NUM_GENERATOR_EPOCHS = 2         # generator epochs inside each GAN epoch (was 40)

# Resuming
FORCE_RESTART = False    # True: start from GAN epoch 0 and ignore checkpoints
RESUME_FROM = 8          # checkpoint N to resume from when FORCE_RESTART is False (-1 = highest found); 0-indexed epoch
CONTINUE = False         # True: start at epoch RESUME_FROM but do not load any saved state

In [None]:
def get_MS_file_pairs(root_dir, split='train', snrs=(0.0, 10.0, 20.0), high=30.0):
  """
  Collect (clean, noisy) file pairs from an MS-SNSD-style directory layout:
    <root_dir>/<split>/clean/clnsp<N>.wav
    <root_dir>/<split>/noisy/noisy<N>_SNRdb_<snr:.1f>_clnsp<N>.wav

  Args:
    root_dir: dataset root directory.
    split: sub-directory to scan (e.g. 'train').
    snrs: SNR levels (dB) whose noisy files go into `data`. Any iterable of numbers.
    high: a single SNR level (dB) whose noisy files go into `high_data`.

  Returns:
    (data, clean_data, high_data), each a list of (path, path) tuples in a deterministic order
    (clean files sorted by name, then by the order of `snrs`):
      data:       (clean_path, noisy_path) for every existing noisy file at an SNR in `snrs`.
      clean_data: (clean_path, clean_path), once per clean file with at least one match in `snrs`.
      high_data:  (clean_path, noisy_path) for existing noisy files at SNR `high`.
  """
  clean_dir = os.path.join(root_dir, split, 'clean')
  noisy_dir = os.path.join(root_dir, split, 'noisy')
  data, clean_data, high_data = [], [], []

  def noisy_path_for(number, snr):
    return os.path.join(noisy_dir, "noisy{}_SNRdb_{:.1f}_clnsp{}.wav".format(number, snr, number))

  # sorted(): os.listdir order is arbitrary, which made the pair order irreproducible.
  for fname in sorted(os.listdir(clean_dir)):
    if not (fname.startswith('clnsp') and fname.endswith('.wav')):
      continue
    try:
      example_number = int(fname[len('clnsp'):-len('.wav')])
    except ValueError:
      # e.g. 'clnsp_backup.wav' -- not part of the numbered dataset.
      continue
    clean_path = os.path.join(clean_dir, fname)

    matched = False
    for snr in snrs:
      noisy_path = noisy_path_for(example_number, snr)
      if os.path.isfile(noisy_path):
        data.append((clean_path, noisy_path))
        matched = True
    # One (clean, clean) pair per clean file, added directly instead of via a set.
    if matched:
      clean_data.append((clean_path, clean_path))

    high_path = noisy_path_for(example_number, high)
    if os.path.isfile(high_path):
      high_data.append((clean_path, high_path))
  return data, clean_data, high_data

def wav_to_spectrogram(wav, normalize=False, n_fft=512):
  """
  Compute a Hamming-windowed STFT of a 1-D waveform (e.g. as returned by librosa.load).

  The signal is zero-padded at the end by one hop (n_fft // 2) before the STFT.

  Args:
    wav: 1-D numpy array of samples.
    normalize: if True, standardize each frequency bin to zero mean / unit std across frames.
    n_fft: FFT and window size; hop length is n_fft // 2. Yields n_fft // 2 + 1 bins (257 for 512).

  Returns:
    magnitude: array of shape (n_frames, n_fft // 2 + 1) -- frames first, as the generator expects.
    phase: array of shape (n_fft // 2 + 1, n_frames) -- NOT transposed.
    orig_length: number of samples in `wav` before padding.
  """
  orig_length = wav.shape[0]
  hop_length = n_fft // 2
  epsilon = 1e-12
  # `size=` must be passed by keyword: librosa >= 0.10 made it keyword-only.
  wav_padded = librosa.util.fix_length(wav, size=orig_length + hop_length)

  # scipy.signal.windows.hamming is the same (symmetric) window as the old scipy.signal.hamming
  # alias, which SciPy 1.13 removed.
  stft = librosa.stft(wav_padded, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
                      window=scipy.signal.windows.hamming)
  magnitude = np.abs(stft)
  phase = np.angle(stft)

  if normalize:
    # Per-bin statistics over the time axis; epsilon guards silent bins.
    mean = np.mean(magnitude, axis=1, keepdims=True)
    std = np.std(magnitude, axis=1, keepdims=True) + epsilon
    magnitude = (magnitude - mean) / std

  return magnitude.T, phase, orig_length

def spectrogram_to_wav(stft, phase, signal_length, n_fft=512):
  """
  Invert wav_to_spectrogram: recombine a magnitude with a phase and run the inverse STFT.

  Args:
    stft: magnitude array shaped like `phase`, i.e. (n_fft // 2 + 1, n_frames) -- the transpose of
      wav_to_spectrogram's magnitude output.
    phase: phase array (radians) as returned by wav_to_spectrogram.
    signal_length: number of samples to return (the reconstruction is trimmed/padded to this).
    n_fft: window size used for the forward transform; hop length is n_fft // 2.

  Returns:
    1-D waveform of length `signal_length`.
  """
  complex_stft = stft * np.exp(1j * phase)
  # Must use the same window as the forward STFT; scipy.signal.windows.hamming replaces the
  # removed scipy.signal.hamming alias.
  return librosa.istft(complex_stft, hop_length=n_fft // 2, win_length=n_fft,
                       window=scipy.signal.windows.hamming, length=signal_length)

def format_time(elapsed):
  """Render a duration in seconds as 'H:MM:SS', rounded to the nearest whole second."""
  return str(datetime.timedelta(seconds=int(round(elapsed))))

In [None]:
# Index the training split: (clean, noisy) pairs, (clean, clean) pairs, and the 30 dB "high" pairs.
file_pairs, clean_pairs, high_pairs = get_MS_file_pairs(DATASET_DIR, split='train')
# Later filled with (clean, enhanced) pairs by the training loop or the checkpoint loader.
new_pairs = list()

In [None]:
# Sanity-check the STFT front end on the first noisy training file.
example = librosa.load(file_pairs[0][1], sr=16000)   # (waveform, sample_rate)
result, phase, orig_length = wav_to_spectrogram(example[0], normalize=False)
for label, array in (("result", result), ("phase", phase)):
  print("Example shape of " + label + ": ", array.shape)

Example shape of result:  (775, 257)
Example shape of phase:  (257, 775)


In [None]:
# Peek at the first (clean, noisy) pair to confirm the path layout.
first_pair = file_pairs[0]
print(first_pair)

('/content/gdrive/MyDrive/MS-SNSD-dataset-30/train/clean/clnsp174.wav', '/content/gdrive/MyDrive/MS-SNSD-dataset-30/train/noisy/noisy174_SNRdb_0.0_clnsp174.wav')


In [None]:
class Generator(nn.Module):
  """
  MetricGAN generator: predicts a per-bin mask from a (normalized) magnitude spectrogram.

  Input:  (batch, n_frames, in_dim), batch-first.
  Output: (batch, n_frames, in_dim), each value in (0, 1) from the final sigmoid.

  Args:
    in_dim: frequency bins per frame (257 for n_fft=512). Also the size of the output mask.
    out_dim: hidden size of each direction of the 2-layer BiLSTM. Despite the name, this is not
      the output size; the name is kept for backward compatibility.
  """
  def __init__(self, in_dim=257, out_dim=200):
    super().__init__()
    # Attribute names/creation order are kept as-is: saved state_dicts and optimizer
    # states are keyed on them.
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.lstms = nn.LSTM(input_size=self.in_dim, hidden_size=self.out_dim, num_layers=2, batch_first=True, bidirectional=True)
    # 2 * out_dim: forward and backward LSTM states are concatenated.
    self.linear1 = nn.Linear(in_features=2*self.out_dim, out_features=300)
    self.leaky_relu = nn.LeakyReLU()
    self.dropout = nn.Dropout(p=0.05)
    # One mask value per input bin; this was hard-coded to 257, which only matched the default in_dim.
    self.linear2 = nn.Linear(in_features=300, out_features=self.in_dim)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """x: (batch, n_frames, in_dim) -> mask of the same shape with values in (0, 1)."""
    lstm_out, _ = self.lstms(x)
    layer1_out = self.dropout(self.leaky_relu(self.linear1(lstm_out)))
    layer2_out = self.sigmoid(self.linear2(layer1_out))
    return layer2_out

In [None]:
class Discriminator(nn.Module):
  """
  MetricGAN discriminator: scores a pair of magnitude spectrograms with a value in (0, 1).

  Input:  (batch_size, 2, n_frames, n_bins). Channel 0 is the spectrogram being judged and
          channel 1 is the clean reference.
  Output: (batch_size, 1), after a sigmoid.

  Each 'valid' convolution shrinks both spatial dims by (kernel_size - 1), so an input needs at
  least 1 + 4 + 6 + 8 + 10 = 29 frames and 29 frequency bins.

  NOTE(review): the attribute names and creation order below are what the saved state_dict and
  optimizer checkpoints are keyed on; renaming or reordering them breaks checkpoint loading.
  """
  def __init__(self, in_dim=257):
    super().__init__()
    # PyTorch convolutions are channels-first, so the two stacked spectrograms form the channel
    # axis: input is (batch_size, 2, n_frames, n_bins).
    # NOTE(review): in_dim is stored but not used to size any layer.
    self.in_dim = in_dim
    # Normalizes each of the 2 input channels.
    self.batch_norm = nn.BatchNorm2d(num_features=2)
    # Four spectral-normalized 'valid' convolutions: 2 -> 15 -> 35 -> 65 -> 90 channels.
    self.conv2d_sn1 = spectral_norm(nn.Conv2d(in_channels=2, out_channels=15, kernel_size=(5,5), padding='valid'))
    self.leaky_relu1 = nn.LeakyReLU()
    self.conv2d_sn2 = spectral_norm(nn.Conv2d(in_channels=15, out_channels=35, kernel_size=(7,7), padding='valid'))
    self.leaky_relu2 = nn.LeakyReLU()
    self.conv2d_sn3 = spectral_norm(nn.Conv2d(in_channels=35, out_channels=65, kernel_size=(9,9), padding='valid'))
    self.leaky_relu3 = nn.LeakyReLU()
    self.conv2d_sn4 = spectral_norm(nn.Conv2d(in_channels=65, out_channels=90, kernel_size=(11,11), padding='valid'))
    self.leaky_relu4 = nn.LeakyReLU()
    # Global average pooling (keras GlobalAveragePooling2D equivalent): AdaptiveAvgPool2d((1, 1))
    # gives (batch_size, 90, 1, 1), then Flatten gives (batch_size, 90).
    self.global_avg_pool = nn.AdaptiveAvgPool2d((1,1))
    self.flatten = nn.Flatten()         # -> (batch_size, 90)
    # Spectral-normalized dense head: 90 -> 50 -> 10 -> 1.
    self.linear1 = spectral_norm(nn.Linear(in_features=90, out_features=50))
    self.leaky_relu5 = nn.LeakyReLU()
    self.linear2 = spectral_norm(nn.Linear(in_features=50, out_features=10))
    self.leaky_relu6 = nn.LeakyReLU()
    self.linear3 = spectral_norm(nn.Linear(in_features=10, out_features=1))
    self.sigmoid = nn.Sigmoid()
    # Input-noise scale; only referenced by the commented-out noise line in forward().
    self.std = 0.1

  def std_step(self):
    """Decay the input-noise scale by 10%. Has no effect while the noise line in forward() is disabled."""
    self.std = self.std * 0.9

  def forward(self, x):
    """x: (batch_size, 2, n_frames, n_bins) -> scores of shape (batch_size, 1) in (0, 1)."""
    x_normalized = self.batch_norm(x)
    # Disabled input-noise augmentation (it would also rely on a global `device`):
    # x_normalized = x_normalized + (self.std**0.5)*torch.randn(x_normalized.shape).to(device)
    conv1_out = self.leaky_relu1(self.conv2d_sn1(x_normalized))
    conv2_out = self.leaky_relu2(self.conv2d_sn2(conv1_out))
    conv3_out = self.leaky_relu3(self.conv2d_sn3(conv2_out))
    conv4_out = self.leaky_relu4(self.conv2d_sn4(conv3_out))
    global_pool_out = self.flatten(self.global_avg_pool(conv4_out))
    linear1_out = self.leaky_relu5(self.linear1(global_pool_out))
    linear2_out = self.leaky_relu6(self.linear2(linear1_out))
    out = self.linear3(linear2_out)
    out = self.sigmoid(out)
    return out

In [None]:
def get_path_for_generator(path, epoch, create=False, base_dir='/content/gdrive/MyDrive/SE-training'):
  """
  Return the output path for the generator's enhanced version of `path` in GAN epoch `epoch`:
  <base_dir>/epoch<epoch>/<basename of path>.

  Args:
    path: path to the noisy input wav; only its file name is used.
    epoch: GAN epoch number used in the directory name.
    create: if True, create <base_dir>/epoch<epoch> (and any missing parents) if needed.
    base_dir: root directory for per-epoch outputs.
  """
  epoch_dir = os.path.join(base_dir, 'epoch{}'.format(epoch))
  if create:
    # makedirs + exist_ok: also creates a missing base_dir and tolerates an existing directory,
    # where the previous exists-then-mkdir could fail on either.
    os.makedirs(epoch_dir, exist_ok=True)
  return os.path.join(epoch_dir, os.path.basename(path))

def get_generator_sample(file_pair):
  """
  Build one generator example from a (clean_path, noisy_path) pair.

  The noisy waveform (resampled to 16 kHz) is multiplied by SCALE_FACTOR before its STFT;
  callers divide reconstructed audio by SCALE_FACTOR afterwards. The clean reference is not scaled.
  NOTE(review): the original rationale for the scale factor cited a librosa clipping issue,
  https://stackoverflow.com/questions/60365904/reconstructing-audio-from-a-melspectrogram-has-some-clipping-with-librosa
  but that issue concerns mel-spectrogram inversion, while this code uses a plain STFT.

  Returns:
    noisy_spectrogram_normalized: tensor (n_frames, n_bins), per-bin standardized -- generator input.
    noisy_spectrogram: tensor (n_frames, n_bins), scaled noisy magnitude the mask is applied to.
    clean_spectrogram: tensor (n_frames_clean, n_bins), clean magnitude.
    mask: tensor shaped like noisy_spectrogram, filled with MASK_MIN_VALUE (floor for the predicted mask).
    phase: numpy array (n_bins, n_frames), phase of the scaled noisy STFT.
    length: number of samples in the noisy waveform.
  """
  clean_file, noisy_file = file_pair
  noisy_wav, _ = librosa.load(noisy_file, sr=16000)
  scaled_noisy = noisy_wav * SCALE_FACTOR
  noisy_spectrogram_normalized, _, _ = wav_to_spectrogram(scaled_noisy, normalize=True)
  noisy_spectrogram, phase, length = wav_to_spectrogram(scaled_noisy)

  clean_wav, _ = librosa.load(clean_file, sr=16000)
  clean_spectrogram, _, _ = wav_to_spectrogram(clean_wav)

  # Spectrograms are (num_frames, frame_dim), matching the generator's batch_first LSTM once a
  # batch dimension is added by the DataLoader.
  noisy_spectrogram_normalized = torch.from_numpy(noisy_spectrogram_normalized)
  noisy_spectrogram = torch.from_numpy(noisy_spectrogram)
  clean_spectrogram = torch.from_numpy(clean_spectrogram)
  # Shape taken from the spectrogram instead of hard-coding 257 bins.
  mask = torch.full(tuple(noisy_spectrogram.shape), MASK_MIN_VALUE)

  return noisy_spectrogram_normalized, noisy_spectrogram, clean_spectrogram, mask, phase, length
  
def get_discriminator_sample(file_pair):
  """
  Build one discriminator example from a (clean_path, other_path) pair.

  `other_path` may be a noisy recording, a generator output, or the clean file itself. Its
  waveform is ALWAYS multiplied by SCALE_FACTOR before the STFT (the unscaled branch for clean
  inputs is disabled); the clean reference is not scaled. The regression target is the STOI of
  the unscaled `other` waveform against the clean waveform.

  Returns:
    stacked: tensor (2, n_frames, n_bins); channel 0 = scaled "other" magnitude, channel 1 = clean magnitude.
    true_stoi: float tensor of shape (1,) holding the STOI target.
    phase: phase of the clean STFT, (n_bins, n_frames).
    length: sample count of the clean waveform.
  """
  clean_file, noisy_file = file_pair
  other_wav = librosa.load(noisy_file, sr=16000)[0]
  other_spec = wav_to_spectrogram(other_wav*SCALE_FACTOR)[0]

  clean_wav = librosa.load(clean_file, sr=16000)[0]
  clean_spec, clean_phase, clean_length = wav_to_spectrogram(clean_wav)

  stoi_score = float(stoi(x=clean_wav, y=other_wav, fs_sig=16000, extended=False))
  true_stoi = torch.tensor([stoi_score])

  # Stack on a new leading axis to get (2, n_frames, n_bins) directly.
  stacked = torch.from_numpy(np.stack((other_spec, clean_spec), axis=0))
  return stacked, true_stoi, clean_phase, clean_length

In [None]:
class GeneratorDataset(Dataset):
  """
  Map-style dataset over (clean_path, noisy_path) pairs.

  Item i is get_generator_sample(file_pairs[i]); audio is loaded lazily on each access.
  """
  def __init__(self, file_pairs):
    super().__init__()
    self.file_pairs = file_pairs   # kept by reference, not copied

  def __len__(self):
    count = len(self.file_pairs)
    return count

  def __getitem__(self, idx):
    pair = self.file_pairs[idx]
    return get_generator_sample(pair)

class DiscriminatorDataset(Dataset):
  """
  Map-style dataset over (clean_path, other_path) pairs for training the discriminator.

  Item i is get_discriminator_sample(file_pairs[i]); audio is loaded lazily on each access.
  """
  def __init__(self, file_pairs):
    super().__init__()
    self.file_pairs = file_pairs   # kept by reference, not copied

  def __len__(self):
    count = len(self.file_pairs)
    return count

  def __getitem__(self, idx):
    pair = self.file_pairs[idx]
    return get_discriminator_sample(pair)

In [None]:
def get_max_checkpt(checkpt_dir):
  """
  Return the largest N among files named exactly 'checkpt-gen-N.pt' in `checkpt_dir`, or 0 if none.
  """
  max_checkpt = 0
  for filename in os.listdir(checkpt_dir):
    # fullmatch with an escaped '.': names like 'checkpt-gen-3.pt.bak' used to match re.match's
    # prefix test and then crash in int('pt'). The number comes straight from the capture group.
    match = re.fullmatch(r"checkpt-gen-([0-9]+)\.pt", filename)
    if match:
      max_checkpt = max(max_checkpt, int(match.group(1)))
  return max_checkpt

def load_latest_checkpt(checkpt_dir=CHECKPT_DIR):
  """
  Restore training state from checkpoint N in `checkpt_dir`.

  N is RESUME_FROM, or the highest N found by get_max_checkpt when RESUME_FROM == -1.
  If N <= 0 nothing is loaded.

  Side effects (when N > 0): loads weights into the global generator/discriminator and state into
  generator_optimizer/discriminator_optimizer, and rebinds the globals new_pairs,
  discriminator_dataset, discriminator_sampler and discriminator_dataloader from npairs_<N>.pkl.

  Returns:
    N, which the caller uses as the (0-indexed) GAN epoch to resume at.
  """
  global new_pairs, discriminator_dataset, discriminator_sampler, discriminator_dataloader
  if RESUME_FROM == -1:
    mx_checkpt = get_max_checkpt(checkpt_dir)
  else:
    mx_checkpt = RESUME_FROM
  if mx_checkpt <= 0:
    return mx_checkpt

  def checkpt_path(kind):
    return os.path.join(checkpt_dir, "checkpt-{}-{}.pt".format(kind, mx_checkpt))

  # map_location='cpu' lets checkpoints saved from a GPU load on a CPU-only runtime.
  # load_state_dict copies weights onto each module's current device, and
  # Optimizer.load_state_dict casts its state to the device of the matching parameters.
  generator.load_state_dict(torch.load(checkpt_path("gen"), map_location="cpu"))
  discriminator.load_state_dict(torch.load(checkpt_path("dis"), map_location="cpu"))
  generator_optimizer.load_state_dict(torch.load(checkpt_path("genopt"), map_location="cpu"))
  discriminator_optimizer.load_state_dict(torch.load(checkpt_path("disopt"), map_location="cpu"))

  # Read from checkpt_dir (previously the CHECKPT_DIR global) so every file comes from one place.
  # SECURITY: pickle.load can execute arbitrary code -- only load pair lists this notebook wrote.
  with open(os.path.join(checkpt_dir, "npairs_{}.pkl".format(mx_checkpt)), 'rb') as f:
    new_pairs = pickle.load(f)
  discriminator_dataset = DiscriminatorDataset(new_pairs)
  discriminator_sampler = RandomSampler(discriminator_dataset)
  discriminator_dataloader = DataLoader(discriminator_dataset, sampler=discriminator_sampler, batch_size=BATCH_SIZE)
  return mx_checkpt

In [None]:
# Datasets and loaders. The generator trains on (clean, noisy) pairs; the discriminator initially
# also sees (clean, clean) pairs.
generator_dataset = GeneratorDataset(file_pairs)
discriminator_dataset = DiscriminatorDataset(file_pairs + clean_pairs)
generator_sampler = RandomSampler(generator_dataset)
discriminator_sampler = RandomSampler(discriminator_dataset)
generator_dataloader = DataLoader(generator_dataset, sampler=generator_sampler, batch_size=BATCH_SIZE)
discriminator_dataloader = DataLoader(discriminator_dataset, sampler=discriminator_sampler, batch_size=BATCH_SIZE)

# Draw one batch from each loader to sanity-check tensor shapes.
generator_sample = next(iter(generator_dataloader))
discriminator_sample = next(iter(discriminator_dataloader))
print(generator_sample[0].shape, generator_sample[1].shape, generator_sample[2].shape, generator_sample[3].shape)
print(discriminator_sample[0].shape, discriminator_sample[1].shape)

# Choose the device before building the models, so the optimizers are constructed over
# parameters that already live there (the order recommended by the torch.optim docs).
if torch.cuda.is_available():
  print("Using GPU: {}".format(torch.cuda.get_device_name(0)))
  device = torch.device("cuda")
else:
  print("No GPUs available, using CPU")
  device = torch.device("cpu")

generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator_optimizer = AdamW(generator.parameters(), lr=1e-4, eps=1e-11)
discriminator_optimizer = AdamW(discriminator.parameters(), lr=2e-5, eps=1e-11)

torch.Size([1, 669, 257]) torch.Size([1, 669, 257]) torch.Size([1, 669, 257]) torch.Size([1, 669, 257])
torch.Size([1, 2, 678, 257]) torch.Size([1, 1])
Using GPU: Tesla P100-PCIE-16GB


In [None]:
# GAN training. Each GAN epoch:
#   1. train the discriminator to regress each pair's STOI (discriminator_dataloader),
#   2. train the generator so the eval-mode discriminator scores its enhanced output as TARGET,
#   3. enhance every training file, write it under SE-training/epoch<k>/, and rebuild the
#      discriminator dataset from those outputs plus high_pairs,
#   4. checkpoint models, optimizers and pairs as N = gan_epoch + 1.
# NOTE: the torch.cat stacking in step 2 assumes BATCH_SIZE == 1.
save = True
if FORCE_RESTART:
  start_epoch = 0
elif CONTINUE:
  # Start from RESUME_FROM without loading state
  start_epoch = RESUME_FROM
else:
  start_epoch = load_latest_checkpt()  # 0-indexed

# Loop-invariant objects, built once instead of on every step.
MSE = nn.MSELoss(reduction='sum')
# Generator target score. This was a hard-coded 2.0, outside the (0, 1) range the
# sigmoid-terminated discriminator can output, while the TARGET config value went unused.
generator_target = torch.tensor([[float(TARGET)]], device=device)

for gan_epoch in range(start_epoch, NUM_GAN_EPOCHS):
  print("<<<<<<<<<<<<<<< GAN epoch {} >>>>>>>>>>>>>>>>>".format(gan_epoch+1))

  # ---- 1. Discriminator ----
  discriminator.train()
  for epoch in range(NUM_DISCRIMINATOR_EPOCHS):
    val_sum = 0.0
    epoch_loss = 0.0
    epoch_start = time.time()
    true_stoi_avg = 0
    print("=============== Discriminator Epoch {} / {} =================".format(epoch+1, NUM_DISCRIMINATOR_EPOCHS))
    for step, batch in enumerate(discriminator_dataloader):
      discriminator.zero_grad()
      input_noisy_discriminator = batch[0].to(device)
      expected_out_noisy = batch[1].to(device)

      outputs_noisy = discriminator(input_noisy_discriminator)
      loss = MSE(outputs_noisy, expected_out_noisy)
      loss.backward()
      # Accumulate Python floats: summing the tensors themselves chained every step's autograd
      # graph together and kept the whole epoch's activations alive.
      epoch_loss += loss.item()
      val_sum += outputs_noisy[0][0].item()
      clip_grad_norm_(discriminator.parameters(), 1.0)
      discriminator_optimizer.step()
      true_stoi_avg += batch[1]
      if step % 10 == 0 and step != 0:
        elapsed = format_time(time.time() - epoch_start)
        judged_spec = input_noisy_discriminator[0,0,:,:].cpu().detach().numpy()
        clean_spec = input_noisy_discriminator[0,1,:,:].cpu().detach().numpy()
        clean_phase = batch[2][0,:,:].cpu().detach().numpy()
        clean_length = batch[3][0].cpu().detach().numpy()
        # Both waveforms reuse the clean signal's phase/length (what get_discriminator_sample
        # returns), so this logged STOI is only an approximation of the true target.
        noisy_wav = spectrogram_to_wav(judged_spec.T, clean_phase, clean_length) / SCALE_FACTOR
        clean_wav = spectrogram_to_wav(clean_spec.T, clean_phase, clean_length)
        print("Sample: {} v/s {} v/s {}".format(outputs_noisy[0][0], expected_out_noisy[0][0], stoi(x=clean_wav, y=noisy_wav, fs_sig=16000, extended=False)))
        print("Batch {} of {}. Elapsed {}".format(step, len(discriminator_dataloader), elapsed))
    avg_train_loss = epoch_loss / (step+1)
    true_stoi_avg = true_stoi_avg / (step+1)
    val_sum = val_sum / (step+1)
    print("Average discriminator training loss for epoch {} : {}".format(epoch+1, avg_train_loss))
    print("Average True STOI for generated outputs last epoch: {}".format(true_stoi_avg[0][0]))
    print("Epoch took {}".format(format_time(time.time()-epoch_start)))
    print("")

  # ---- 2. Generator ----
  discriminator.eval()
  generator.train()
  for epoch in range(NUM_GENERATOR_EPOCHS):
    epoch_loss = 0.0
    print("============= Generator Epoch {} / {} =================".format(epoch+1, NUM_GENERATOR_EPOCHS))
    epoch_start = time.time()
    for step, batch in enumerate(generator_dataloader):
      generator.zero_grad()
      input_generator = batch[0].to(device)
      noisy_audio = batch[1].to(device)
      # Not named `clean`: that global name belongs to the clean() helper defined later.
      clean_spec = batch[2].to(device)
      min_mask = batch[3].to(device)

      output_generator = generator(input_generator)
      mask = torch.maximum(output_generator, min_mask)
      cleaned = torch.mul(mask, noisy_audio)
      # (1, T, F) + (1, T, F) -> (2, T, F) -> (1, 2, T, F); channel order matches get_discriminator_sample.
      stacked = torch.unsqueeze(torch.cat((cleaned, clean_spec), axis=0), 0)
      discriminator_output = discriminator(stacked)
      loss = MSE(discriminator_output, generator_target)
      loss.backward()
      epoch_loss += loss.item()
      clip_grad_norm_(generator.parameters(), 1.0)
      generator_optimizer.step()
      if step % 10 == 0 and step != 0:
        elapsed = format_time(time.time() - epoch_start)
        cleaned_np = cleaned.squeeze().cpu().detach().numpy()
        clean_np = clean_spec.squeeze().cpu().detach().numpy()
        noisy_phase = batch[4][0,:,:].cpu().detach().numpy()
        noisy_length = batch[5][0].cpu().detach().numpy()
        cleaned_wav = spectrogram_to_wav(cleaned_np.T, noisy_phase, noisy_length) / SCALE_FACTOR
        clean_wav = spectrogram_to_wav(clean_np.T, noisy_phase, noisy_length)
        print("Sample: {} v/s {}".format(discriminator_output[0][0], stoi(x=clean_wav, y=cleaned_wav, fs_sig=16000, extended=False)))
        print("Batch {} of {}. Elapsed {}".format(step, len(generator_dataloader), elapsed))
    avg_train_loss = epoch_loss / len(generator_dataloader)
    print("Average generator training loss for epoch {} : {}".format(epoch+1, avg_train_loss))
    print("Epoch took {}".format(format_time(time.time()-epoch_start)))
    print("")

  # ---- 3. Regenerate the discriminator's pairs from the current generator ----
  print("Saving new files for next epoch")
  generator.eval()
  os.makedirs('/content/gdrive/MyDrive/SE-training/epoch{}'.format(gan_epoch), exist_ok=True)
  new_pairs = []
  avg_stoi = 0
  # Pure inference: no_grad avoids building an autograd graph for every file.
  with torch.no_grad():
    for file_pair in file_pairs:
      batch = get_generator_sample(file_pair)
      input_generator = batch[0].unsqueeze(0).to(device)
      noisy_audio = batch[1].unsqueeze(0).to(device)
      min_mask = batch[3].unsqueeze(0).to(device)

      new_pair_name = (file_pair[0], get_path_for_generator(file_pair[1], gan_epoch))
      output_generator = generator(input_generator)
      mask = torch.maximum(output_generator, min_mask)
      cleaned = torch.mul(mask, noisy_audio).squeeze().cpu().numpy()
      cleaned_wav = spectrogram_to_wav(cleaned.T, batch[4], batch[5]) / SCALE_FACTOR
      orig_clean = librosa.load(file_pair[0], sr=16000)[0]
      s = stoi(x=orig_clean, y=cleaned_wav, fs_sig=16000, extended=False)
      avg_stoi += s
      sf.write(new_pair_name[1], cleaned_wav, 16000)
      new_pairs.append(new_pair_name)
  new_pairs += high_pairs
  avg_stoi /= len(file_pairs)
  print("Average STOI: {}".format(avg_stoi))
  # New dataset for discriminator
  discriminator_dataset = DiscriminatorDataset(new_pairs)
  discriminator_sampler = RandomSampler(discriminator_dataset)
  discriminator_dataloader = DataLoader(discriminator_dataset, sampler=discriminator_sampler, batch_size=BATCH_SIZE)

  # ---- 4. Checkpoint as N = gan_epoch + 1 (what load_latest_checkpt resumes from) ----
  if save:
    ckpt = gan_epoch + 1
    torch.save(generator.state_dict(), os.path.join(CHECKPT_DIR, "checkpt-gen-{}.pt".format(ckpt)))
    torch.save(discriminator.state_dict(), os.path.join(CHECKPT_DIR, "checkpt-dis-{}.pt".format(ckpt)))
    torch.save(generator_optimizer.state_dict(), os.path.join(CHECKPT_DIR, "checkpt-genopt-{}.pt".format(ckpt)))
    torch.save(discriminator_optimizer.state_dict(), os.path.join(CHECKPT_DIR, "checkpt-disopt-{}.pt".format(ckpt)))
    # `with` guarantees the pickle is flushed and closed before the next epoch starts.
    with open(os.path.join(CHECKPT_DIR, "npairs_{}.pkl".format(ckpt)), 'wb') as f:
      pickle.dump(new_pairs, f)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Batch 2580 of 6492. Elapsed 0:21:43
Sample: 0.8588351607322693 v/s 0.8629838824272156 v/s 0.8842645295535324
Batch 2590 of 6492. Elapsed 0:21:49
Sample: 0.8019071817398071 v/s 0.7491417527198792 v/s 0.8061253175149075
Batch 2600 of 6492. Elapsed 0:21:53
Sample: 0.9069784283638 v/s 0.9242422580718994 v/s 0.9371953693172869
Batch 2610 of 6492. Elapsed 0:21:58
Sample: 0.8859620094299316 v/s 0.9381535649299622 v/s 0.9560934278200134
Batch 2620 of 6492. Elapsed 0:22:03
Sample: 0.8872944712638855 v/s 0.8725707530975342 v/s 0.9222196722830637
Batch 2630 of 6492. Elapsed 0:22:08
Sample: 0.6976976990699768 v/s 0.6792864799499512 v/s 0.7507370829836862
Batch 2640 of 6492. Elapsed 0:22:13
Sample: 0.7672492861747742 v/s 0.8379219174385071 v/s 0.8639492286263397
Batch 2650 of 6492. Elapsed 0:22:18
Sample: 0.8607964515686035 v/s 0.8904073238372803 v/s 0.9170445185562012
Batch 2660 of 6492. Elapsed 0:22:23
Sample: 0.8314743041992188 v/s

KeyboardInterrupt: ignored

In [None]:
def run_generator_on_path(in_path, out_path=None):
  """
  Enhance one audio file with the global `generator` and write the result as a 16 kHz wav.

  Audio is loaded at 16 kHz and scaled by SCALE_FACTOR. The predicted mask, floored at
  MASK_MIN_VALUE, is applied to the STFT magnitude; the result is inverted with the noisy phase
  and scaled back down.

  Args:
    in_path: input audio path.
    out_path: destination; defaults to '<in_path up to .wav>-out.wav'.
  """
  if out_path is None:
    out_path = in_path.split('.wav')[0] + '-out.wav'
  noisy_wav, _ = librosa.load(in_path, sr=16000)
  scaled_noisy = noisy_wav * SCALE_FACTOR
  normalized_spec, _, _ = wav_to_spectrogram(scaled_noisy, normalize=True)
  noisy_spec, phase, length = wav_to_spectrogram(scaled_noisy)

  # Add a batch dimension: the generator's LSTM is batch_first -> (1, n_frames, n_bins).
  normalized_spec = torch.from_numpy(normalized_spec).unsqueeze(0).to(device)
  noisy_spec = torch.from_numpy(noisy_spec).unsqueeze(0).to(device)
  # Shape taken from the spectrogram instead of hard-coding 257 bins.
  min_mask = torch.full(tuple(noisy_spec.shape), MASK_MIN_VALUE, device=device)

  # Inference must not apply dropout and needs no gradients. The generator may still be in train
  # mode (e.g. after an interrupted training run), so switch to eval and restore the caller's mode.
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      mask = torch.maximum(generator(normalized_spec), min_mask)
      cleaned = torch.mul(mask, noisy_spec).squeeze().cpu().numpy()
  finally:
    generator.train(was_training)

  cleaned_wav = spectrogram_to_wav(cleaned.T, phase, length) / SCALE_FACTOR
  sf.write(out_path, cleaned_wav, 16000)

In [None]:
# Enhance the test-main noisy files (clean-speech ids 1..60 at 0 dB and 10 dB SNR) into TO_DIR.
FROM_DIR = "/content/gdrive/MyDrive/MS-SNSD-dataset-30/test-main/noisy/"
TO_DIR = "/content/gdrive/MyDrive/saved/tt1/test-main/"
for num in range(1, 61):
  for dblevel in (0.0, 10.0):
    name_wav = "noisy{}_SNRdb_{:.1f}_clnsp{}.wav".format(num, dblevel, num)
    run_generator_on_path(FROM_DIR + name_wav, TO_DIR + name_wav)

In [None]:
def mix_audio_wavs(signal, noise, snr):
  """
  Add `noise` to `signal` at the requested signal-to-noise ratio.

  The noise is tiled (or truncated) to len(signal) and scaled by a gain g chosen so that
  10 * log10(mean(signal**2) / mean((g * noise)**2)) == snr. The scaled noise is added to the
  unmodified signal; the mix is not renormalized, so it can be louder than `signal`.

  Adapted from https://stackoverflow.com/questions/71915018/mix-second-audio-clip-at-specific-snr-to-original-audio-file-in-python

  Args:
    signal: 1-D array of samples.
    noise: 1-D array of samples; must be non-empty.
    snr: target SNR in dB.

  Returns:
    float32 array of len(signal). If the noise is silent (zero energy), the signal is returned
    unchanged, since no gain can reach the target SNR.

  Raises:
    ValueError: if `noise` is empty.
  """
  noise = np.asarray(noise)
  if len(noise) == 0:
    raise ValueError("noise must contain at least one sample")
  # Repeat the noise so it covers the whole signal (truncates it if it is longer).
  noise = noise[np.arange(len(signal)) % len(noise)]

  # Work in float32: integer PCM dtypes (e.g. uint8/int16) would overflow when squared below.
  noise = noise.astype(np.float32)
  signal = np.asarray(signal).astype(np.float32)

  signal_energy = np.mean(signal**2)
  noise_energy = np.mean(noise**2)
  if noise_energy == 0:
    # The gain below would be inf, and inf * 0 produces an all-NaN mix.
    return signal
  g = np.sqrt(10.0 ** (-snr/10) * signal_energy / noise_energy)
  return signal + g * noise

def add_noise(sound_file, noise_file, snr=0.0, out_file=None):
  """
  Load speech and noise at 16 kHz, mix them at `snr` dB with mix_audio_wavs, and write a wav.

  out_file defaults to '<sound_file up to .wav>-noisy.wav'.
  """
  destination = out_file if out_file is not None else sound_file.split('.wav')[0] + '-noisy.wav'
  speech = librosa.load(sound_file, sr=16000)[0]
  background = librosa.load(noise_file, sr=16000)[0]
  sf.write(destination, mix_audio_wavs(speech, background, snr), 16000)

def convert_m4a_to_wav(m4a, wav=None):
  """
  Decode an .m4a file with pydub and write it as .wav.

  Args:
    m4a: input path.
    wav: output path; defaults to '<m4a up to .m4a>.wav'.

  Returns:
    The path of the written .wav (the function previously returned None).
  """
  if wav is None:
    wav = m4a.split(".m4a")[0] + ".wav"
  track = AudioSegment.from_file(m4a, format='m4a')
  # AudioSegment.export returns the still-open output file object; close it so the file
  # descriptor is released and the data flushed deterministically.
  track.export(wav, format='wav').close()
  return wav

def full_cycle(file_path, noise_path, snr=0.0, from_m4a=True):
  """
  Demo pipeline: optionally convert .m4a -> .wav, add noise at `snr` dB, then enhance.

  Writes '<name>.wav' (when converting), '<name>-noisy.wav' and '<name>-noisy-out.wav'.
  """
  wav_path = file_path
  if from_m4a:
    convert_m4a_to_wav(file_path)
    wav_path = file_path.split('.m4a')[0] + '.wav'
  add_noise(wav_path, noise_path, snr)
  run_generator_on_path(wav_path.split('.wav')[0] + '-noisy.wav')

def clean(file_path, from_m4a=True):
  """
  Enhance a recording, writing '<name>-out.wav' next to it.

  With from_m4a=True (default) `file_path` is an .m4a that is first converted to '<name>.wav'.
  """
  if not from_m4a:
    return run_generator_on_path(file_path)
  convert_m4a_to_wav(file_path)
  run_generator_on_path(file_path.split('.m4a')[0] + '.wav')

def clean_video(file_path):
  """
  Extract the audio track of an .mp4, enhance it, and mux the enhanced audio back onto the video.

  Writes '<name>.wav' (extracted audio), '<name>-out.wav' (enhanced audio) and
  '<name>-cleaned.mp4'. Requires the `ffmpeg` executable on PATH.

  Raises:
    subprocess.CalledProcessError: if either ffmpeg invocation fails (previously ignored).
  """
  base = file_path.split(".mp4")[0]
  audio_path = base + ".wav"
  # Argument lists instead of a shell string: paths with spaces or shell metacharacters work and
  # are never interpreted by a shell. -y overwrites outputs from a previous run instead of
  # blocking on ffmpeg's interactive overwrite prompt.
  subprocess.run(["ffmpeg", "-y", "-i", file_path, "-ab", "160k", "-ac", "2", "-ar", "44100",
                  "-vn", audio_path], check=True)
  # The extracted track is already a .wav. With the default from_m4a=True, clean() would try to
  # decode it as m4a and then look for '<name>.wav.wav'.
  clean(audio_path, from_m4a=False)
  cleaned_path = audio_path.split(".wav")[0] + "-out.wav"
  new_path = base + "-cleaned.mp4"
  subprocess.run(["ffmpeg", "-y", "-i", file_path, "-i", cleaned_path, "-c:v", "copy",
                  "-map", "0:v:0", "-map", "1:a:0", new_path], check=True)

In [None]:
# Inputs for the ad-hoc demos below; currently only clean_video is run.
noise_p = '/content/gdrive/MyDrive/saved/tt1/noise/VacuumCleaner_1.wav'   # alternative: .../noise/Bus_1.wav
m4a_p = '/content/gdrive/MyDrive/saved/tt1/audio/shaila3.m4a'
video_p = "/content/gdrive/MyDrive/saved/tt1/audio/tomscott-esa.mp4"
# full_cycle(m4a_p, noise_p)
# clean(m4a_p)
clean_video(video_p)