In [1]:
#!conda env update --file environment.yml --prune

In [2]:
import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator, EventEnum
from ignite.metrics import Accuracy, Loss, Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.exceptions import NotComputableError

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# pd.options.display.max_seq_items = 2000
pd.set_option("display.max_colwidth",None)

import os, random, glob, logging, ntpath, math, time
logging.basicConfig()
logger=logging.getLogger("dbg")
logger.setLevel(logging.DEBUG)
perf_logger=logging.getLogger("perf")
perf_logger.setLevel(logging.DEBUG)
# logging.disable(logging.DEBUG)

from IPython import display
from IPython.display import Audio
from tqdm import tqdm

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)


DEBUG:dbg:cuda:0
DEBUG:dbg:2.6.0+cu124


In [3]:
BATCH_SIZE = 4
SHUFFLE = False
FRAME_SHIFT = 40

# Utility

In [4]:
logger.debug(torchaudio.list_audio_backends())

DEBUG:dbg:['ffmpeg']


In [None]:
import pystoi
import pesq

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        # speech = torch.tensor(speech, dtype=torch.float64, device=speech.device)
        speech = speech.to(torch.float64,non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        # noise = torch.tensor(noise, dtype=torch.float64, device=noise.device)
        noise = noise.to(torch.float64,non_blocking=True)
    if not(type(snr) is torch.Tensor):
        snr = torch.tensor([snr])
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    out = audioF.add_noise(speech, noise, snr)
    return out

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    return pesq.pesq(ref=speech, deg=processed, fs=16000)

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    return pystoi.stoi(x=speech, y=processed, fs_sig=16000)

def ns_to_sec(ns: int):
    return ns/1000000000.0

def plot_waveform(waveform, sample_rate=16000):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")


def stitch_audio(frames_func, frame_size, frame_shift):
    *batches, at_end = frames_func()

    # audio[mix, clean]
    audio = [torch.zeros((b.shape[0],b.shape[1],640000),dtype=torch.float) for b in batches if len(b.shape)==3]
    pos=0
    end=frame_size
    for batch_i, batch in enumerate(batches):
        audio[batch_i][:,:,:end] = batch[:,:,:]

    frame_slice_start = frame_size - frame_shift
    while not at_end:
        pos = end
        end += frame_shift
        *batches, at_end = frames_func()
        for batch_i, batch in enumerate((batches)):
            audio[batch_i][:,0,pos:end] = batch[:,0,frame_slice_start:]

            if at_end:
                print(audio[batch_i].shape)
                audio[batch_i] = torch.tensor(audio[batch_i][:,:,:end]) 
        
    for i in range(BATCH_SIZE):
        print(audio[0][i][0])

    return audio

## Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

# ldrnd = random.Random(42)   #   Used for noise loading 

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedBatchDataset(Dataset):
    def __init__(self, mixed: list, clean: list, batch_size: int = 1):
        self.batch_size = batch_size
        self.mixed = mixed
        self.clean = clean
        # if root_dir==None:
        #     root_dir = os.getcwd()+"/data"
        # self.root_dir = root_dir

    def __len__(self):
        return math.floor(len(self.mixed)/self.batch_size)
    
    def __getitem__(self,idx):
        mixeds = []
        cleans = []
        max_len = 0
        for i in range(self.batch_size):
            if idx + i > len(self.mixed): continue
            mixed_wave, _ = torchaudio.load(self.mixed[idx*self.batch_size+i])
            clean_wave, _ = torchaudio.load(self.clean[idx*self.batch_size+i])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert(mixed_wave.shape[0]==clean_wave.shape[0])
            if i==0:
                max_len = mixed_wave.shape[0]
            else:
                mixed_wave = torch.nn.functional.pad(mixed_wave,(0,max_len-(mixed_wave.shape[0])),value=0.0)
                clean_wave = torch.nn.functional.pad(clean_wave,(0,max_len-(clean_wave.shape[0])),value=0.0)
            # logger.debug(mixed_wave)
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)
        mixeds = np.asarray(mixeds)
        cleans = np.asarray(cleans)
        return torch.tensor(mixeds).unsqueeze(1), torch.tensor(cleans).unsqueeze(1)

    def split(self, val_pct, seed=None):
        rnd = random.Random()
        if seed is not None:
            rnd.seed(seed)
        this_len = len(self)
        val_batches = math.floor(this_len*val_pct)
        val_indices = sorted(rnd.sample(range(this_len), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(this_len):
            if i in val_indices:
                for x in range(self.batch_size):
                    val_mixed.append(self.mixed[i*self.batch_size+x])
                    val_clean.append(self.clean[i*self.batch_size+x])
            else:
                for x in range(self.batch_size):
                    train_mixed.append(self.mixed[i*self.batch_size+x])
                    train_clean.append(self.clean[i*self.batch_size+x])
        
        return SortedBatchDataset(train_mixed,train_clean,self.batch_size), SortedBatchDataset(val_mixed,val_clean,self.batch_size)


class FrameLoaderEvents(EventEnum):
    END_OF_BATCH = "end_of_batch"

class FrameLoader():
    '''Takes a dataloader, frame size and frame shift. It can then be iterated over to produce frames.\n
    Provides padding when sample length would be exceeded.\n
    Returns (mix, clean, has_batch_ended)'''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, engine: Engine | None = None):
        self.dl = dl
        self.dl_iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        self.at_end = True
        self.engine = engine
    def __iter__(self):
        self.dl_iter = iter(self.dl)
        return self
    def __next__(self):
        if self.at_end:
            batches: tuple[torch.Tensor,torch.Tensor] = next(self.dl_iter)
            self.batch_mixed = batches[0].to(device=device).squeeze_(0)
            self.batch_clean = batches[1].to(device=device).squeeze_(0)
            self.frame_position = 0
            self.at_end = False
            # logger.debug(f"mixed batch shape:{self.batch_mixed.shape} | clean batch shape:{self.batch_clean.shape}")
            # mix_maxes = [torch.max(s[0]) for s in self.batch_mixed]
            # clean_maxes = [torch.max(s[0]) for s in self.batch_clean]
            # logger.debug(f"mix_maxes:{str(mix_maxes)} | clean_maxes:{str(clean_maxes)}")
        
        frame_end = self.frame_position + self.frame_size
        frames = []
        for batch_i, batch in enumerate([self.batch_mixed, self.batch_clean]):
            shp = batch.shape
            frame: torch.Tensor
            if frame_end >= shp[2]:
                if self.engine is not None and batch_i==0: 
                    self.engine.fire_event(FrameLoaderEvents.END_OF_BATCH)
                    # logger.debug("emitted FrameLoaderEvents.END_OF_SAMPLE")
                self.at_end = True
                if frame_end != shp[2]:
                    diff = frame_end - shp[2]
                    # Pad batch until aligned with frame_end
                    frame = torch.zeros((BATCH_SIZE,1,self.frame_size),dtype=torch.float32).to(device="cpu")
                    frame[:, 0, 0:self.frame_size - diff] = batch[:, 0, self.frame_position:shp[2]]
                else:
                    frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)
            else:
                frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)
            frames.append(frame)

        self.frame_position += self.frame_shift
        # logger.debug(f"FrameLoader: frame_position[{self.frame_position}] - frame_end[{frame_end}]")
        # perf_logger.debug(f"Time to load frame at ({self.frame_position}): {time.perf_counter() - start}s")
        if frames[0].shape[2] != 96:
            logger.debug(frames[0].shape[2])
            logger.debug("FrameLoader issue")
        
        # print("frameloader - mix")
        # for smp in frames[0]:
        #     smp = smp[0]
        #     print(torch.max(smp))

        # print("frameloader - clean")
        # for smp in frames[1]:
        #     smp = smp[0]
        #     print(torch.max(smp))

        # print("\n")
        return frames[0], frames[1], self.at_end


_dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=BATCH_SIZE)
train_dataset, val_dataset = _dataset.split(0.2,42)
del _dataset
base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
base_val_dataloader = DataLoader(val_dataset)

In [7]:
print(len(train_dataset))

print(len(val_dataset))

8
2


# Models

In [8]:
class PESQMetric(Metric):
    def __init__(self, output_transform = lambda x: x, device=device):
        self._running_total=0.0
        self._num=0
        super().__init__(output_transform, device)
    @reinit__is_reduced
    def reset(self):
        self._running_total=0.0
        self._num=0
        super().reset()
    @reinit__is_reduced
    def update(self, output):
        y_pred: np.ndarray
        y: np.ndarray
        y_pred, y = output[0].cpu().numpy(), output[1].cpu().numpy()
        for i in range(BATCH_SIZE):
            self._running_total += calc_pesq(y[i][0], y_pred[i][0])
            self._num += 1
    @sync_all_reduce("_num","_running_total:SUM")
    def compute(self):
        if self._num == 0:
            raise NotComputableError("PESQ Metric must have one example before computing")
        return self._running_total / self._num

class STOIMetric(Metric):
    def __init__(self, output_transform = lambda x: x, device=device):
        self._running_total=0.0
        self._num=0
        super().__init__(output_transform, device)
    @reinit__is_reduced
    def reset(self):
        self._running_total=0.0
        self._num=0
        super().reset()
    @reinit__is_reduced
    def update(self, output):
        y_pred: np.ndarray
        y: np.ndarray
        y_pred, y = output[0].detach().numpy(), output[1].detach().numpy()
        for i in range(BATCH_SIZE):
            self._running_total += calc_stoi(y[i][0], y_pred[i][0])
            self._num += 1
    @sync_all_reduce("_num","_running_total:SUM")
    def compute(self):
        if self._num == 0:
            raise NotComputableError("STOI Metric must have one example before computing")
        return self._running_total / self._num

criterion = nn.MSELoss()
train_metrics: dict[str, Metric] = {
    "loss": Loss(criterion),
}
val_metrics: dict[str, Metric] = {
    "loss": Loss(criterion),
    "pesq": PESQMetric(),
    "stoi": STOIMetric()
}
def register_custom_events(eng: Engine):
    eng.register_events(*FrameLoaderEvents)

def log_trainer_loss(eng: Engine, **kwargs):
    prefix = kwargs.get("prefix","")
    print(f"{prefix}Epoch[{eng.state.epoch}], Iter[{eng.state.iteration}] Loss: {eng.state.output}")

def log_eval_results(eng: Engine, **kwargs):
    prefix = kwargs.get("prefix","")
    val_evaluator: Engine = kwargs.get("val_evaluator",None)
    val_frame_loader: FrameLoader = kwargs.get("val_frame_loader",None)
    if val_evaluator is None:
        raise TypeError("log_eval_results must be passed the argument `val_evaluator` of type `Engine`")
    if val_frame_loader is None:
        raise TypeError("log_eval_results must be passed the argument `val_frame_loader` of type `FrameLoader`")
    
    val_evaluator.run(val_frame_loader)
    metrics = val_evaluator.state.metrics
    print(f"{prefix}Epoch[{eng.state.epoch}] | Accuracy:[{metrics['accuracy']:.2f}] | PESQ:[{metrics['pesq']:.2f}] | STOI:[{metrics['stoi']:.2f}] | Loss:[{metrics['loss']}]")




## SEGAN

In [9]:
from models.segan import Generator, Discriminator

def segan_train(models: tuple):
    generator, discriminator = models

## WaveCRN

In [None]:
CRN_FRAME_SIZE = 96
CRN_TRAIN = False

logging.disable(logging.DEBUG)

from models.wavecrn import ConvBSRU
logger.debug("wavecrn loaded")

crn_model: ConvBSRU
crn_optimizer: torch.optim.Adam
if 'crn_model' in locals(): del crn_model
if 'crn_optimizer' in locals(): del crn_optimizer

torch.cuda.empty_cache()

crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
if not CRN_TRAIN:
    crn_model.load_state_dict(torch.load("saved_models/crn.pt", weights_only=True))
crn_optimizer = torch.optim.Adam(crn_model.parameters(),lr=1.0e-5)

def wavecrn_train_step(engine, batch):
    # s0 = time.perf_counter()
    crn_model.train()
    crn_optimizer.zero_grad()
    x, y = batch[0].to(device), batch[1].to(device)
    # s1 = time.perf_counter()
    y_pred = crn_model(x)
    # logger.debug(f"x_shape:{x.shape}, y_shape:{y.shape}, y_pred_shape={y_pred.shape}")
    # e1 = time.perf_counter()
    loss = criterion(y_pred, y)
    # s2 = time.perf_counter()
    loss.backward()
    crn_optimizer.step()
    # e2 = time.perf_counter()
    # perf_logger.debug(f"Time for [y_pred = crn_model(x)]:{e1-s1}s | Time for [backprop]:{e2-s2}s")
    # perf_logger.debug(f"Total in [wavecrn_train_step()]: {time.perf_counter()-s0}")
    
    return loss.item()

crn_trainer = Engine(wavecrn_train_step)
register_custom_events(crn_trainer)
crn_trainer.add_event_handler(FrameLoaderEvents.END_OF_BATCH,log_trainer_loss,prefix="Batch completed | ")
crn_trainer.add_event_handler(Events.ITERATION_COMPLETED(every=200),log_trainer_loss)

crn_train_dataloader = FrameLoader(base_train_dataloader, CRN_FRAME_SIZE, FRAME_SHIFT, crn_trainer)

def crn_validation_step(engine, batch):
    crn_model.eval()
    with torch.no_grad():
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = crn_model(x)
        return y_pred, y

crn_train_evaluator = Engine(crn_validation_step)
crn_val_evaluator = Engine(crn_validation_step)

for name, metric in train_metrics.items():
    metric.attach(crn_train_evaluator, name)
for name, metric in val_metrics.items():
    metric.attach(crn_val_evaluator, name)

crn_val_dataloader = FrameLoader(base_val_dataloader, CRN_FRAME_SIZE, FRAME_SHIFT)
# crn_trainer.add_event_handler(Events.EPOCH_COMPLETED,log_eval_results,val_evaluator=crn_val_evaluator,val_frame_loader=crn_val_dataloader)

if CRN_TRAIN:
    crn_trainer.run(crn_train_dataloader, max_epochs=10)



logging.disable(logging.NOTSET)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
torch.save(crn_model.state_dict(),"saved_models/crn.pt")

In [23]:
logging.disable(logging.NOTSET)
def crn_eval_contiguous(_iter):
    def auto_iterate_validation_step():
        batch = next(_iter)
        crn_model.eval()
        with torch.no_grad():
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = crn_model(x)
            return y_pred, y, batch[2]
    
    out = stitch_audio(auto_iterate_validation_step,CRN_FRAME_SIZE,FRAME_SHIFT)
    return out[0], out[1]

def test_stitching(_iter):
    def _get_batch():
        return next(_iter)
    
    out = stitch_audio(_get_batch,CRN_FRAME_SIZE,FRAME_SHIFT)
    return out[0], out[1]

_frame_iter = iter(crn_val_dataloader)
_frameloader_iter = iter(FrameLoader(base_val_dataloader, CRN_FRAME_SIZE, FRAME_SHIFT))

In [24]:
# mix, clean = crn_eval_contiguous(_frame_iter)
mix, clean = test_stitching(_frameloader_iter)
print(f"Mix shape: {mix.shape} | Clean shape: {clean.shape}")


for i in range(BATCH_SIZE):
    # print(f"Sample {i}: mix max={torch.max(mix[i][0])}, mix max={torch.max(clean[i][0])}")
    _mix = mix[i][0].numpy()
    _clean = clean[i][0].numpy()

    if _clean.max()==0:
        continue
    print(_mix)
    print("Mixed:")
    display.display(Audio(_mix,rate=16000))
    print("Clean:")
    display.display(Audio(_clean,rate=16000))

    print(f"PESQ: {calc_pesq(_clean, _mix)}")
    print(f"STOI: {calc_stoi(_clean, _mix)}")



torch.Size([4, 1, 640000])
torch.Size([4, 1, 640000])
tensor([-0.0039, -0.0050, -0.0035,  ...,  0.0000,  0.0000,  0.0000])
tensor([-0.0100, -0.0095, -0.0091,  ...,  0.0000,  0.0000,  0.0000])
tensor([-0.1530, -0.1476, -0.1427,  ...,  0.0000,  0.0000,  0.0000])
tensor([-0.0095,  0.0055, -0.0174,  ...,  0.0000,  0.0000,  0.0000])
Mix shape: torch.Size([4, 1, 23136]) | Clean shape: torch.Size([4, 1, 23136])
[-0.00390625 -0.0050354  -0.00354004 ...  0.          0.
  0.        ]
Mixed:


  frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)
  audio[batch_i] = torch.tensor(audio[batch_i][:,:,:end])


Clean:


PESQ: 1.022855520248413
STOI: 0.6661222610350591
[-0.00997925 -0.00952148 -0.00909424 ...  0.          0.
  0.        ]
Mixed:


Clean:


PESQ: 1.4511520862579346
STOI: 0.7195535911449249
[-0.15301514 -0.14761353 -0.14273071 ...  0.          0.
  0.        ]
Mixed:


Clean:


PESQ: 1.3806889057159424
STOI: 0.9767853649766188
[-0.00946045  0.00546265 -0.0173645  ...  0.          0.
  0.        ]
Mixed:


Clean:


PESQ: 1.2175393104553223
STOI: 0.7781692393146383


In [None]:
_flit = iter(FrameLoader(base_train_dataloader,CRN_FRAME_SIZE,FRAME_SHIFT))
def _get_next():
    return next(_flit)

mix, clean = stitch_audio(_get_next,CRN_FRAME_SIZE,FRAME_SHIFT)
for i in range(BATCH_SIZE):
    _mix = mix[i][0].numpy()
    _clean = clean[i][0].numpy()
    if _clean.max()==0:
        continue
    print(f"PESQ: {calc_pesq(_clean, _mix)}")
    print(f"STOI: {calc_stoi(_clean, _mix)}")

    print("Mixed:")
    display.display(Audio(_mix,rate=16000))
    print("Clean:")
    display.display(Audio(_clean,rate=16000))

## RHR-Net

In [None]:
import yaml
from models.rhrnetdir.Arg_Parser import Recursive_Parse
from models.rhrnet import RHRNet
hp = Recursive_Parse(yaml.load(
    open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8'),
    Loader=yaml.Loader
    ))  
net = RHRNet(hp)
net(torch.randn(3, 1, 1024))

def rhrnet_train(model):


## Wave-U-Net

In [None]:
from models.waveunet import Waveunet
waveunet_model = Waveunet()

def waveunet_train(model):
    