In [2]:
#!conda env update --file environment.yml --prune

In [1]:
import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Show up to 2000 items when pandas prints sequences (e.g. long Index objects).
pd.set_option("display.max_seq_items", 2000)

import os, random, glob, logging, ntpath, math
# Root handler so debug records are actually emitted; notebook code logs
# through the "dbg" logger at DEBUG verbosity.
logging.basicConfig()
logger = logging.getLogger(name="dbg")
logger.setLevel(level=logging.DEBUG)

from IPython.display import Audio
import tqdm

# First CUDA GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Utility

In [3]:
# Record which torchaudio I/O backends are available in this environment.
logger.debug("%s", torchaudio.list_audio_backends())

DEBUG:dbg:['ffmpeg']


In [4]:
import pystoi
import pesq


def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        # speech = torch.tensor(speech, dtype=torch.float64, device=speech.device)
        speech = speech.to(torch.float64,non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        # noise = torch.tensor(noise, dtype=torch.float64, device=noise.device)
        noise = noise.to(torch.float64,non_blocking=True)
    if not(type(snr) is torch.Tensor):
        snr = torch.tensor([snr])
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    out = audioF.add_noise(speech, noise, snr)
    return out

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    return pesq.pesq(ref=speech, deg=processed, fs=16000)

def calc_stoi(speech: np.ndarray, processed: np.ndarray, fs: int = 16000, extended: bool = False) -> float:
    """STOI intelligibility score of ``processed`` against the clean ``speech``.

    Args:
        speech: Clean reference signal.
        processed: Degraded or enhanced signal to score.
        fs: Sample rate of both signals in Hz.
        extended: Forwarded to ``pystoi.stoi`` to request the extended STOI variant.

    Returns:
        The score reported by ``pystoi.stoi``.
    """
    return pystoi.stoi(x=speech, y=processed, fs_sig=fs, extended=extended)

In [5]:
def plot_waveform(waveform, sample_rate=16000):
    """Plot every channel of a waveform against time, one subplot per channel.

    Args:
        waveform: Tensor (on any device, with or without grad) or array-like,
            shaped ``(channels, frames)``, or ``(frames,)`` for a single channel.
        sample_rate: Samples per second; used to express the x axis in seconds.
    """
    if isinstance(waveform, torch.Tensor):
        # .numpy() alone fails on CUDA tensors and on tensors that require grad.
        samples = waveform.detach().cpu().numpy()
    else:
        samples = np.asarray(waveform)
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]  # treat 1-D input as a single channel

    num_channels, num_frames = samples.shape
    time_axis = np.arange(num_frames) / sample_rate

    # squeeze=False always returns a 2-D array of Axes, even for one channel.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for c in range(num_channels):
        ax = axes[c, 0]
        ax.plot(time_axis, samples[c], linewidth=1)
        ax.grid(True)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {c+1}")
    axes[-1, 0].set_xlabel("Time [s]")
    figure.suptitle("waveform")

In [79]:
from torch.utils.data import Dataset, DataLoader

# Dedicated RNG with a fixed seed, intended for noise selection. At the moment
# it is only referenced by the commented-out AudioCombiner code in this cell.
ldrnd = random.Random()
ldrnd.seed(42)

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedSeqShuffledBatch(Dataset):
    """Dataset whose items are whole, pre-built batches of (mixed, clean) waveform pairs.

    Item ``idx`` loads pairs ``idx*batch_size`` .. ``idx*batch_size + batch_size - 1``
    from the two path lists (paired by position). For each file, only the first
    channel is kept. Every waveform is zero-padded on the right to the longest
    waveform in the batch, and the waveforms are stacked.

    Args:
        mixed: Paths of the noisy (mixed) wav files.
        clean: Paths of the matching clean speech wav files.
        batch_size: Number of pairs per item; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, mixed: list, clean: list, batch_size: int = 1):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.mixed = mixed
        self.clean = clean

    def __len__(self):
        # A trailing partial batch is dropped; only pairs present in *both*
        # lists are counted so __getitem__ never indexes past the shorter list.
        return min(len(self.mixed), len(self.clean)) // self.batch_size

    def __getitem__(self, idx):
        """Return ``(mixed, clean)`` tensors of shape ``(batch_size, max_len)`` moved to ``device``.

        Raises:
            ValueError: if a mixed file and its clean counterpart differ in length.
        """
        start = idx * self.batch_size
        mixeds, cleans = [], []
        for i in range(start, start + self.batch_size):
            mixed_wave, _ = torchaudio.load(self.mixed[i])
            clean_wave, _ = torchaudio.load(self.clean[i])
            mixed_wave, clean_wave = mixed_wave[0], clean_wave[0]  # first channel only
            if mixed_wave.shape[0] != clean_wave.shape[0]:
                raise ValueError(
                    f"length mismatch for pair {i}: mixed {self.mixed[i]!r} has "
                    f"{mixed_wave.shape[0]} samples, clean {self.clean[i]!r} has {clean_wave.shape[0]}"
                )
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)

        # Pad to the longest waveform in the batch, not just the first one:
        # F.pad with a negative amount silently crops, which would drop audio
        # from any later item longer than the first.
        max_len = max(wave.shape[0] for wave in mixeds)

        def pad(wave):
            return torch.nn.functional.pad(wave, (0, max_len - wave.shape[0]), value=0.0)

        mixed_batch = torch.stack([pad(wave) for wave in mixeds])
        clean_batch = torch.stack([pad(wave) for wave in cleans])
        return mixed_batch.to(device=device), clean_batch.to(device=device)


# class AudioCombiner():
#     def __init__(self, speech_ds: AudioDataset, noise_ds: AudioDataset):
#         self.idx = 0
#         self.speech_ds = speech_ds
#         self.noise_ds = noise_ds

#     def next(self, snr):
#         self.idx += 1
#         if self.idx >= len(self.speech_ds):
#             raise IndexError(f"Index {self.idx} out of bounds of speech dataset ({len(self.speech_ds)}).")
#         speech = self.speech_ds.__getitem__(self.idx)
#         noise = self.noise_ds.__getitem__(ldrnd.randint(0,len(self.noise_ds)-1))

#         return audioF.add_noise(speech, noise, torch.tensor(snr))

#     def set_index(self, idx):
#         self.idx = idx

SHUFFLE = False
# Each dataset item is already a batch of 4 (mixed, clean) pairs; the outer
# DataLoader uses batch_size=1, so it adds a leading dimension of size 1.
train_dataset = SortedSeqShuffledBatch(
    mixed=get_sequential_wav_paths("data/mixed/train"),
    clean=get_sequential_wav_paths("data/speech_ordered/train"),
    batch_size=4,
)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=SHUFFLE, batch_size=1)

In [None]:
# Peek at the first mixed batch only to check its shape.
t_batch, _ = next(iter(train_dataloader))
print(t_batch.shape, t_batch[0].shape, sep="\n")

torch.Size([1, 4, 522320])
torch.Size([4, 522320])


# Models

## SEGAN

In [None]:
from models import SEGAN

def segan_train(models: tuple):
    """Train a SEGAN generator/discriminator pair (not implemented yet).

    Args:
        models: ``(generator, discriminator)`` pair.

    Raises:
        ValueError: if ``models`` does not unpack into exactly two items.
        NotImplementedError: always, until the training loop is written, so a
            call cannot silently "succeed" without training anything.
    """
    generator, discriminator = models  # validates the pair early
    # TODO: implement the adversarial training loop.
    raise NotImplementedError("segan_train: SEGAN training loop is not implemented yet")

## WaveCRN

In [None]:
from models import WaveCRN

def wavecrn_train(model: nn.Module):
    """Train a WaveCRN model (not implemented yet).

    The original ``def`` had no body, which is an IndentationError that stopped
    this whole cell (including the import) from running.

    Raises:
        NotImplementedError: always, until the training loop is written.
    """
    # TODO: implement the WaveCRN training loop.
    raise NotImplementedError("wavecrn_train: WaveCRN training loop is not implemented yet")


## RHR-Net

In [None]:
import yaml
from models.rhrnetdir.Arg_Parser import Recursive_Parse
from models import RHRNet
# Load RHR-Net hyperparameters; the context manager closes the file handle
# deterministically instead of leaking it until garbage collection.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is acceptable
# for this repo-local config, but switch to yaml.safe_load if the file could
# ever come from an untrusted source.
with open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8') as hp_file:
    hp = Recursive_Parse(yaml.load(hp_file, Loader=yaml.Loader))
net = RHRNet(hp)

# Forward pass on random (3, 1, 1024) input. The output is discarded, so this
# only checks that the model runs. no_grad avoids building an unused autograd graph.
with torch.no_grad():
    net(torch.randn(3, 1, 1024))

def rhrnet_train(model):
    """Train an RHR-Net model (not implemented yet).

    The original ``def`` had no body, which is an IndentationError that stopped
    this whole cell from running.

    Raises:
        NotImplementedError: always, until the training loop is written.
    """
    # TODO: implement the RHR-Net training loop.
    raise NotImplementedError("rhrnet_train: RHR-Net training loop is not implemented yet")


## Wave-U-Net

In [None]:
from models import WaveUNet

def waveunet_train(model):
    