In [1]:
#!conda env update --file environment.yml --prune

In [2]:
import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator, EventEnum
from ignite.metrics import Accuracy, Loss, Metric

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
pd.options.display.max_seq_items = 2000

import os, random, glob, logging, ntpath, math
logging.basicConfig()
logger=logging.getLogger("dbg")
logger.setLevel(logging.DEBUG)

from IPython.display import Audio
import tqdm

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)

DEBUG:dbg:cuda:0
DEBUG:dbg:2.6.0+cu124


# Utility

In [3]:
logger.debug(torchaudio.list_audio_backends())

DEBUG:dbg:['ffmpeg']


In [4]:
import pystoi
import pesq

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        # speech = torch.tensor(speech, dtype=torch.float64, device=speech.device)
        speech = speech.to(torch.float64,non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        # noise = torch.tensor(noise, dtype=torch.float64, device=noise.device)
        noise = noise.to(torch.float64,non_blocking=True)
    if not(type(snr) is torch.Tensor):
        snr = torch.tensor([snr])
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    out = audioF.add_noise(speech, noise, snr)
    return out

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    return pesq.pesq(ref=speech, deg=processed, fs=16000)

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    return pystoi.stoi(x=speech, y=processed, fs_sig=16000)


def plot_waveform(waveform, sample_rate=16000):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")

In [5]:
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 4
SHUFFLE = False
FRAME_SHIFT = 40

# ldrnd = random.Random(42)   #   Used for noise loading 

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedSeqShuffledBatch(Dataset):
    def __init__(self, mixed: list, clean: list, batch_size: int = 1):
        self.batch_size = batch_size
        self.mixed = mixed
        self.clean = clean
        # if root_dir==None:
        #     root_dir = os.getcwd()+"/data"
        # self.root_dir = root_dir

    def __len__(self):
        return math.floor(len(self.mixed)/self.batch_size)
    
    def __getitem__(self,idx):
        mixeds = []
        cleans = []
        max_len = 0
        for i in range(self.batch_size):
            if idx + i > len(self.mixed): continue
            mixed_wave, _ = torchaudio.load(self.mixed[idx*self.batch_size+i])
            clean_wave, _ = torchaudio.load(self.clean[idx*self.batch_size+i])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert(mixed_wave.shape[0]==clean_wave.shape[0])
            if i==0:
                max_len = mixed_wave.shape[0]
            else:
                mixed_wave = torch.nn.functional.pad(mixed_wave,(0,max_len-(mixed_wave.shape[0])),value=0.0)
                clean_wave = torch.nn.functional.pad(clean_wave,(0,max_len-(clean_wave.shape[0])),value=0.0)
            # logger.debug(mixed_wave)
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)
        mixeds = np.asarray(mixeds)
        cleans = np.asarray(cleans)
        return torch.tensor(mixeds).unsqueeze(1), torch.tensor(cleans).unsqueeze(1)

    def split(self, val_pct, seed=None):
        rnd = random.Random()
        if seed is not None:
            rnd.seed(seed)
        this_len = len(self)
        val_batches = math.floor(this_len*val_pct)
        val_indices = sorted(rnd.sample(range(this_len), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(this_len):
            if i in val_indices:
                for x in range(self.batch_size):
                    val_mixed.append(self.mixed[i*self.batch_size+x])
                    val_clean.append(self.clean[i*self.batch_size+x])
            else:
                for x in range(self.batch_size):
                    train_mixed.append(self.mixed[i*self.batch_size+x])
                    train_clean.append(self.clean[i*self.batch_size+x])
        
        return SortedSeqShuffledBatch(train_mixed,train_clean,self.batch_size), SortedSeqShuffledBatch(val_mixed,val_clean,self.batch_size)


class FrameLoaderEvents(EventEnum):
    END_OF_SAMPLE = "end_of_sample"

class FrameLoader():
    '''Takes a dataloader, frame size and frame shift. It can then be iterated over to produce frames.\n
    Provides padding when sample length would be exceeded.'''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, engine: Engine | None = None):
        self.iter = iter(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        self.at_end = True
        self.engine = engine
    def __iter__(self):
        return self
    def __next__(self):
        if self.at_end:
            self.batch_mixed, self.batch_clean = next(self.iter)
            self.batch_mixed.squeeze_(0)
            self.batch_clean.squeeze_(0)
            self.frame_position = 0
            self.at_end = False
          
        for batch in (self.batch_mixed, self.batch_clean):
            frame_end = self.frame_position + self.frame_size
            shp = batch.shape
            logger.debug(f"4: {batch.shape}")
            if frame_end >= shp[2]:
                if self.engine is not None: self.engine.fire_event(FrameLoaderEvents.END_OF_SAMPLE)
                self.at_end = True
                if frame_end != shp[2]:
                    diff = frame_end - shp[2]
                    #Pad batch until aligned with frame_end
                    for i in range(len(batch)):
                        batch[0][i] = torch.nn.functional.pad(batch[0][i],(0,diff),value=0.0)
                    logger.debug(f"5: {batch.shape}")
        
        self.frame_position += self.frame_shift
        logger.debug(f"end: {self.batch_mixed.shape}")
        return self.batch_mixed, self.batch_clean


dataset = SortedSeqShuffledBatch(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=BATCH_SIZE)
train_dataset, val_dataset = dataset.split(0.2,42)
del dataset
base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)

In [6]:
print(len(train_dataset))

print(len(val_dataset))

20
5


# Models

In [7]:
criterion = nn.MSELoss()
val_metrics : dict[str, Metric]= {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}
def register_custom_events(eng: Engine):
    eng.register_events(*FrameLoaderEvents)



## SEGAN

In [None]:
from models.SEGAN import Generator, Discriminator

def segan_train(models: tuple):
    generator, discriminator = models

## WaveCRN

In [8]:
from models.WaveCRN import ConvBSRU
logger.debug("wavecrn loaded")
CRN_FRAME_SIZE = 96


crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
crn_optimizer = torch.optim.Adam(crn_model.parameters())


def wavecrn_train_step(engine, batch):
    crn_model.train()
    crn_optimizer.zero_grad()
    x, y = batch[0].to(device), batch[1].to(device)
    y_pred = crn_model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    crn_optimizer.step()
    return loss.item()

trainer = Engine(wavecrn_train_step)
register_custom_events(trainer)

def crn_validation_step(engine, batch):
    crn_model.eval()
    with torch.no_grad():
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = crn_model(x)
        return y_pred, y

crn_train_evaluator = Engine(crn_validation_step)
crn_val_evaluator = Engine(crn_validation_step)

for name, metric in val_metrics.items():
    metric.attach(crn_train_evaluator, name)

crn_train_dataloader = FrameLoader(base_train_dataloader, CRN_FRAME_SIZE, FRAME_SHIFT, trainer)
a,b = next(iter(crn_train_dataloader))
print(a.shape)

trainer.run(crn_train_dataloader)




If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Error building extension 'sru_cuda': [1/2] C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\nvcc --generate-dependencies-with-compile --dependency-output sru_cuda_kernel.cuda.o.d -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=sru_cuda -DTORCH_API_INCLUDE_EXTENSION_H -Id:\Anaconda\Miniconda3\envs\fyp\Lib\site-packages\torch\include -Id:\Anaconda\Miniconda3\envs\fyp\Lib\site-packages\torch\include\torch\csrc\api\include -Id:\Anaconda\Miniconda3\envs\fyp\Lib\site-packages\t

torch.Size([4, 1, 514320])


ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception: Caught an unknown exception!
ERROR:ignite.engine.engine.Engine:Engine run is terminating due to exception: Caught an unknown exception!


RuntimeError: Caught an unknown exception!

## RHR-Net

In [None]:
import yaml
from models.rhrnetdir.Arg_Parser import Recursive_Parse
from models.RHRNet import RHRNet
hp = Recursive_Parse(yaml.load(
    open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8'),
    Loader=yaml.Loader
    ))  
net = RHRNet(hp)
net(torch.randn(3, 1, 1024))

def rhrnet_train(model):


## Wave-U-Net

In [None]:
from models.WaveUNet import Waveunet
waveunet_model = Waveunet()

def waveunet_train(model):
    