In [10]:
#!conda env update --file environment.yml --prune

In [1]:
import os, random, glob, logging, ntpath, math, time, sys, datetime, json, traceback
from IPython import display
from IPython.display import Audio
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# pd.options.display.max_seq_items = 2000
# Never truncate long cell values when pandas renders a frame.
pd.set_option("display.max_colwidth",None)

# Two named loggers are used throughout the notebook:
#   "dbg"  - general debug output
#   "perf" - timing output
logging.basicConfig()
logger=logging.getLogger("dbg")
# NOTSET clears any global logging.disable() threshold left over from an earlier run of this cell.
logging.disable(logging.NOTSET)
logger.setLevel(logging.DEBUG)
perf_logger=logging.getLogger("perf")
perf_logger.setLevel(logging.DEBUG)
# Uncomment to silence every DEBUG (and lower) record globally:
# logging.disable(logging.DEBUG)

import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)
# get_device_name() only accepts CUDA devices, so it must not be called on the CPU fallback.
if device != "cpu":
    logger.debug(torch.cuda.get_device_name(device))
logger.debug(torchaudio.list_audio_backends())

from ignite.engine import Engine, Events, EventEnum
from ignite.metrics import Loss, Metric, RunningAverage
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.exceptions import NotComputableError
from ignite.handlers.tqdm_logger import ProgressBar
from ignite.handlers import Checkpoint, DiskSaver

DEBUG:dbg:cuda:0
DEBUG:dbg:2.6.0+cu124
DEBUG:dbg:NVIDIA RTX A4000
DEBUG:dbg:['ffmpeg']


In [2]:
# Notebook-wide settings.
# NOTE(review): BATCH_SIZE and FRAME_SHIFT are not read by any live code visible in this
# part of the notebook (the per-model hyper-parameter dicts carry their own values) - confirm before relying on them.
BATCH_SIZE = 8
# Passed as `shuffle=` to the training DataLoader of each model.
SHUFFLE = True
FRAME_SHIFT = 40
# Width (in samples) of the buffer FrameReconstructor allocates per batch row.
MAXIMUM_SAMPLE_NUM_OF_FRAMES = 640000

# Per-model run switches.
# NOTE(review): presumably consumed by cells further down (e.g. via %%skip_if) - verify.
RUN_GAN = False
RUN_CRN = True
RUN_RNN = False
RUN_CNN = False

# Utility

In [3]:
from IPython.core.magic import register_cell_magic
from IPython import get_ipython

@register_cell_magic
def skip(line, cell):
    """Cell magic ``%%skip``: ignore both the magic line and the cell body, executing nothing."""
    pass

@register_cell_magic
def skip_if(line, cell):
    """Cell magic ``%%skip_if <expr>``: run the cell body only when ``<expr>`` is falsy.

    Args:
        line: Python expression written after the magic name; it is evaluated with ``eval``.
        cell: Source of the cell, handed to IPython's ``run_cell`` when not skipped.
    """
    # NOTE(review): eval() executes arbitrary code; acceptable only because the expression
    # is typed by the notebook author. Never feed it external input.
    should_skip = eval(line)
    if not should_skip:
        get_ipython().run_cell(cell)

In [4]:
import pystoi
import pesq

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    """Mix `noise` into `speech` at the requested signal-to-noise ratio.

    Args:
        speech: Waveform tensor; integer tensors are converted to float64 first.
        noise: Noise tensor; integer tensors are converted to float64 first.
        snr: Signal-to-noise ratio in dB. A plain number is wrapped in a tensor that
            lives on `speech`'s device and has one dimension fewer than `speech`,
            which is the leading-dimension layout `add_noise` checks for. A tensor
            is passed through unchanged.

    Returns:
        The mixture produced by `torchaudio.functional.add_noise`, cast to float32.
    """
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        speech = speech.to(torch.float64, non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        noise = noise.to(torch.float64, non_blocking=True)
    if not isinstance(snr, torch.Tensor):
        # Build the scalar on the same device as the audio (a CPU tensor would clash with
        # CUDA input) and give it size-1 leading dims so it broadcasts over the batch.
        snr = torch.tensor(snr, device=speech.device).reshape((1,) * (speech.ndim - 1))
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    return audioF.add_noise(speech, noise, snr).to(dtype=torch.float)

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    """Score `processed` against the reference `speech` with `pesq.pesq` at a fixed 16 kHz rate."""
    sample_rate_hz = 16000
    score = pesq.pesq(fs=sample_rate_hz, ref=speech, deg=processed)
    return score

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    """Score `processed` against the reference `speech` with `pystoi.stoi` at a fixed 16 kHz rate."""
    sample_rate_hz = 16000
    score = pystoi.stoi(x=speech, y=processed, fs_sig=sample_rate_hz)
    return score

def ns_to_sec(ns: int):
    """Convert a duration in nanoseconds to (float) seconds."""
    nanoseconds_per_second = 1e9
    return ns / nanoseconds_per_second

def datetime_string():
    """Return the current local time formatted as ``DD-MM-YYYY--HH-MM-SS`` (safe for file names)."""
    now = datetime.datetime.now()
    return f"{now:%d-%m-%Y--%H-%M-%S}"

def plot_waveform(waveform, sample_rate=16000):
    """Plot every channel of a waveform tensor against time in seconds.

    Args:
        waveform: Tensor of shape (channels, frames); a 1-D tensor is treated as a
            single channel. The tensor may live on any device and may track gradients.
        sample_rate: Samples per second, used to scale the time axis.
    """
    # detach() + cpu() make the conversion safe for CUDA and grad-tracking tensors,
    # for which a bare .numpy() raises.
    samples = waveform.detach().cpu().numpy()
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]

    num_channels, num_frames = samples.shape
    time_axis = np.arange(num_frames) / sample_rate

    # squeeze=False always yields a 2-D axes array, so the single-channel case needs no special handling.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for c in range(num_channels):
        axis = axes[c, 0]
        axis.plot(time_axis, samples[c], linewidth=1)
        axis.grid(True)
        if num_channels > 1:
            axis.set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")

def write_fstring_file(model_name: str, format_string: str, **args):
    """Render `format_string` with `args` and write it to a timestamped text file.

    The file is created at ``saved_models/<model_name>_<timestamp>.txt``; the parent
    directory is created when missing.

    Args:
        model_name: Prefix of the output file name.
        format_string: ``str.format`` template.
        **args: Values substituted into the template.
    """
    path = f"saved_models/{model_name}_{datetime_string()}.txt"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # The file must be opened for writing: the default mode is read-only, which
    # fails on a file that does not exist yet and can never be written to.
    with open(path, "w", encoding="utf-8") as f:
        f.write(format_string.format(**args))

def standardize_batch(batch: torch.Tensor, coerce_func = lambda x: x):
    """Squeeze `batch` and normalise it towards a 2-D tensor.

    After removing all size-1 dimensions:
      * a 2-D result is returned as is,
      * a 1-D result gains a leading dimension of size 1,
      * any other rank is handed to `coerce_func` and its result returned.
    """
    squeezed = batch.squeeze()
    rank = squeezed.dim()
    if rank == 1:
        return squeezed.unsqueeze(dim=0)
    if rank != 2:
        return coerce_func(squeezed)
    return squeezed



# def stitch_audio(frames_func, frame_size, frame_shift):
#     *batches, at_end = frames_func()

#     # audio[mix, clean]
#     audio = [torch.zeros((b.shape[0],b.shape[1],640000),dtype=torch.float) for b in batches if len(b.shape)==3]
#     pos=0
#     end=frame_size
#     for batch_i, batch in enumerate(batches):
#         audio[batch_i][:,:,:end] = batch[:,:,:]

#     frame_slice_start = frame_size - frame_shift
#     while not at_end:
#         pos = end
#         end += frame_shift
#         *batches, at_end = frames_func()
#         for batch_i, batch in enumerate((batches)):
#             audio[batch_i][:,0,pos:end] = batch[:,0,frame_slice_start:]

#             if at_end:
#                 print(audio[batch_i].shape)
#                 audio[batch_i] = torch.tensor(audio[batch_i][:,:,:end]) 
        
#     for i in range(BATCH_SIZE):
#         print(audio[0][i][0])

#     return audio

## Dataset

In [5]:
from torch.utils.data import Dataset, DataLoader

# ldrnd = random.Random(42)   #   Used for noise loading 

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedBatchDataset(Dataset):
    """Dataset whose items are whole batches of (mixed, clean) waveform pairs.

    Item `idx` covers the files at positions ``idx*batch_size`` to
    ``(idx+1)*batch_size - 1`` of the two path lists. Files beyond the last full
    batch are never served.

    Args:
        mixed: Paths of the mixed recordings.
        clean: Paths of the matching clean recordings, in the same order.
        batch_size: Number of files per item.
    """

    def __init__(self, mixed: list, clean: list, batch_size: int):
        self.mixed = mixed
        self.clean = clean
        self.batch_size = batch_size

    def __len__(self):
        """Number of complete batches."""
        return len(self.mixed) // self.batch_size

    def __getitem__(self,idx):
        """Load batch `idx` and return ``(mixed, clean)`` tensors of shape (batch_size, max_len).

        Only the first channel of every file is used. Shorter waveforms are
        zero-padded on the right up to the longest waveform of the batch.

        Raises:
            IndexError: If `idx` does not address a complete batch.
        """
        num_batches = len(self)
        if idx < 0:
            idx += num_batches
        # Explicit bounds check: it also terminates plain `for batch in dataset` iteration cleanly.
        if not 0 <= idx < num_batches:
            raise IndexError(f"batch index {idx} out of range for {num_batches} batches")

        start = idx * self.batch_size
        mixeds = []
        cleans = []
        for file_i in range(start, start + self.batch_size):
            mixed_wave, _ = torchaudio.load(self.mixed[file_i])
            clean_wave, _ = torchaudio.load(self.clean[file_i])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert mixed_wave.shape[0] == clean_wave.shape[0], "mixed/clean pair differs in length"
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)

        # Pad to the true maximum of the batch. Padding to the first item's length would
        # hand F.pad a negative amount for any longer sample and silently crop it.
        max_len = max(wave.shape[0] for wave in mixeds)

        def pad_to_max(wave):
            return torch.nn.functional.pad(wave, (0, max_len - wave.shape[0]), value=0.0)

        return (torch.stack([pad_to_max(wave) for wave in mixeds]),
                torch.stack([pad_to_max(wave) for wave in cleans]))

    def split(self, val_pct, seed=None):
        """Split into ``(train, validation)`` datasets at batch granularity.

        Args:
            val_pct: Fraction of the batches (rounded down) that go to validation.
            seed: Optional seed making the random choice of validation batches reproducible.

        Returns:
            Two `SortedBatchDataset` objects with the same batch size; the file
            order inside each is the order of this dataset.
        """
        rnd = random.Random()
        if seed is not None:
            rnd.seed(seed)
        num_batches = len(self)
        val_batches = math.floor(num_batches*val_pct)
        # A set gives O(1) membership tests in the loop below.
        val_indices = set(rnd.sample(range(num_batches), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(num_batches):
            start = i * self.batch_size
            stop = start + self.batch_size
            if i in val_indices:
                val_mixed.extend(self.mixed[start:stop])
                val_clean.extend(self.clean[start:stop])
            else:
                train_mixed.extend(self.mixed[start:stop])
                train_clean.extend(self.clean[start:stop])

        return SortedBatchDataset(train_mixed,train_clean,self.batch_size), SortedBatchDataset(val_mixed,val_clean,self.batch_size)

class FrameLoaderEvents(EventEnum):
    """Custom ignite events fired by `FrameLoader` on the engine it was given."""
    # Fired when the frame being produced reaches or passes the end of the currently loaded batch.
    END_OF_BATCH = "end_of_batch"

class FrameLoader():
    '''Iterator that cuts the batches of a dataloader into overlapping, fixed-size frames.

    Every call to `next()` yields ``(mixed_frame, clean_frame, has_batch_ended)``.
    Consecutive frames start `frame_shift` samples apart. The last frame of a batch
    is zero-padded on the right when the batch is shorter than the frame's end.
    Once a batch has ended, the following call pulls the next batch from the dataloader.

    Args:
        dl: Dataloader yielding ``(mixed, clean)`` pairs; each is squeezed and, if that
            leaves a 1-D tensor, given a leading dimension of size 1.
        frame_size: Number of samples per frame.
        frame_shift: Number of samples between the starts of consecutive frames.
        batch_size: Number of rows allocated for every frame.
        engine: Optional ignite engine on which `FrameLoaderEvents.END_OF_BATCH`
            is fired when a batch ends.
        output_transform: Applied to every frame before it is returned.
    '''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, batch_size: int, engine: Engine | None = None, output_transform = lambda x: x):
        self.dl = dl
        self.dl_iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_size = batch_size
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        self.at_end = True
        self.engine = engine
        self.output_transform = output_transform

    def __iter__(self):
        self.dl_iter = iter(self.dl)
        # Start from a clean slate so a previous, abandoned pass cannot leak a
        # half-consumed batch into this one.
        self.frame_position = 0
        self.at_end = True
        return self

    def _load_next_batch(self):
        '''Pull the next (mixed, clean) pair from the dataloader and rewind the frame cursor.'''
        batches: tuple[torch.Tensor,torch.Tensor] = next(self.dl_iter)
        self.batch_mixed = batches[0].squeeze_()
        if len(self.batch_mixed.shape) == 1:
            self.batch_mixed.unsqueeze_(0)
        self.batch_clean = batches[1].squeeze_()
        if len(self.batch_clean.shape) == 1:
            self.batch_clean.unsqueeze_(0)
        self.frame_position = 0
        self.at_end = False

    def __next__(self):
        if self.at_end:
            # StopIteration from the exhausted dataloader propagates and ends this iterator too.
            self._load_next_batch()

        frame_end = self.frame_position + self.frame_size
        frames = []
        for batch_i, batch in enumerate((self.batch_mixed, self.batch_clean)):
            total = batch.shape[-1]
            if frame_end >= total:
                # Fire once per frame (for the mixed batch only), before the frame is built.
                if self.engine is not None and batch_i == 0:
                    self.engine.fire_event(FrameLoaderEvents.END_OF_BATCH)
                self.at_end = True
            # Copy whatever part of the frame exists in the batch; the rest stays zero.
            stop = min(frame_end, total)
            frame = torch.zeros((self.batch_size, self.frame_size), dtype=torch.float32)
            frame[:, 0:stop - self.frame_position] = batch[:, self.frame_position:stop]
            frames.append(self.output_transform(frame))

        self.frame_position += self.frame_shift

        # Sanity log only: an output_transform that changes the last dimension is suspicious.
        if frames[0].shape[-1] != self.frame_size:
            logger.debug(frames[0].shape[-1])
            logger.debug("FrameLoader issue")

        return frames[0], frames[1], self.at_end

class FrameReconstructor():
    '''Rebuilds a batch of audio samples from overlapping frames.

    Each call to `add_frame()` appends one batch of frames to an internal buffer;
    from the second frame on, the first ``frame_size - frame_shift`` samples (the
    overlap with the previous frame) are dropped. `get_current_audio()` returns
    everything written so far, and `reset()` clears the buffer for the next batch.

    The buffer holds `MAXIMUM_SAMPLE_NUM_OF_FRAMES` samples per batch row.

    Args:
        frame_size: Number of samples per frame.
        frame_shift: Number of new samples contributed by each frame after the first.
        batch_size: Number of rows in the buffer.
        output_transform: Applied to the audio returned by `get_current_audio()`.
    '''
    def __init__(self, frame_size: int, frame_shift: int, batch_size: int, output_transform = lambda x: x):
        self.audio: torch.Tensor = torch.zeros((batch_size, MAXIMUM_SAMPLE_NUM_OF_FRAMES),dtype=torch.float)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.pos: int = 0
        self.end: int = frame_size
        self.frame_slice_start: int = 0
        self.at_end = False
        self.output_transform = output_transform

    def add_frame(self, batch: torch.Tensor, _at_end = False):
        '''Append one batch of frames to the buffer and return `_at_end` unchanged.'''
        batch = standardize_batch(batch)
        batch = batch.reshape((self.audio.shape[0], batch.shape[-1]))
        self.audio[:,self.pos:self.end] = batch[:,self.frame_slice_start:]

        # Every later frame contributes only its last `frame_shift` samples.
        self.pos = self.end
        self.end += self.frame_shift
        self.frame_slice_start = self.frame_size - self.frame_shift

        #   Remove if _at_end ends up doing something
        return _at_end

    def get_current_audio(self):
        '''Return a copy of the audio written so far, passed through `output_transform`.'''
        # `end` was already advanced past the last write, hence the subtraction.
        # clone().detach() is the recommended way to copy a tensor; torch.tensor(tensor)
        # performs the same copy but emits a UserWarning.
        out = self.audio[:,0:self.end - self.frame_shift].clone().detach()
        out = self.output_transform(out)
        return out

    def reset(self):
        '''Replace the buffer with zeros and rewind all write positions.'''
        self.audio = torch.zeros(self.audio.shape,dtype=torch.float)
        self.pos = 0
        self.end = self.frame_size
        self.frame_slice_start = 0
        self.at_end = False


def get_reference_batch(ds: Dataset, frame_size: int, output_transform=lambda x: x):
    """Draw a random `frame_size`-wide window from a random item of `ds`.

    Items are sampled (and their shape printed) until one is at least `frame_size`
    samples long in its last dimension; the same window is then cut from both
    tensors of that item.

    Returns:
        ``(output_transform(first_window), output_transform(second_window))``.
    """
    while True:
        pair = ds[random.randint(0, len(ds) - 1)]
        first, second = pair[0], pair[1]
        print(first.shape)
        total = first.shape[-1]
        if total >= frame_size:
            break
    start = random.randint(0, total - frame_size)
    stop = start + frame_size
    return output_transform(first[:, start:stop]), output_transform(second[:, start:stop])

# Models

In [6]:
# criterion = nn.MSELoss()

# Module-level performance counters; each training function resets them on entry.
# pf_train_totals[0] accumulates forward-pass time in nanoseconds (see crn's train_step).
# NOTE(review): slot [1] presumably holds backward-pass time - confirm against the rest of crn().
pf_train_totals = [0,0]                                                     ###
pf_train_num_loops = 0                                                      ###
pf_eval_total = 0
pf_eval_num_loops = 0

class PESQMetric(Metric):
    """Ignite metric: mean PESQ score over all fully reconstructed samples.

    `update()` only acts on outputs that carry a third element: a dict holding
    the stitched processed and clean audio under the two names in `stitch_keys`.

    Args:
        stitch_keys: ``(processed_key, clean_key)`` looked up in the output dict.
        output_transform: Forwarded to `Metric`.
        device: Forwarded to `Metric`.
    """
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        self.stitch_keys=stitch_keys
        self.running_total=0.0
        self.num=0
        super().__init__(output_transform, device)

    @reinit__is_reduced
    def reset(self):
        self.running_total=0.0
        self.num=0
        super().reset()

    @reinit__is_reduced
    def update(self, output):
        # Honour the configured key instead of a hard-coded "stitch_proc"; otherwise
        # custom stitch_keys would make the metric skip every update.
        if len(output)<=2 or self.stitch_keys[0] not in output[2]: return
        y_pred: np.ndarray = standardize_batch(output[2][self.stitch_keys[0]]).cpu().numpy()
        y: np.ndarray = standardize_batch(output[2][self.stitch_keys[1]]).cpu().numpy()
        # One score per batch row.
        for i in range(y.shape[0]):
            self.running_total += calc_pesq(y[i], y_pred[i])
            self.num += 1

    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        if self.num == 0:
            raise NotComputableError("PESQ Metric must have one complete sample before computing")
        return self.running_total / self.num

class STOIMetric(Metric):
    """Ignite metric: mean STOI score over all fully reconstructed samples.

    `update()` only acts on outputs that carry a third element: a dict holding
    the stitched processed and clean audio under the two names in `stitch_keys`.

    Args:
        stitch_keys: ``(processed_key, clean_key)`` looked up in the output dict.
        output_transform: Forwarded to `Metric`.
        device: Forwarded to `Metric`.
    """
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        self.stitch_keys=stitch_keys
        self.running_total=0.0
        self.num=0
        super().__init__(output_transform, device)

    @reinit__is_reduced
    def reset(self):
        self.running_total=0.0
        self.num=0
        super().reset()

    @reinit__is_reduced
    def update(self, output):
        # Honour the configured key instead of a hard-coded "stitch_proc"; otherwise
        # custom stitch_keys would make the metric skip every update.
        if len(output)<=2 or self.stitch_keys[0] not in output[2]: return
        y_pred: np.ndarray = standardize_batch(output[2][self.stitch_keys[0]]).cpu().numpy()
        y: np.ndarray = standardize_batch(output[2][self.stitch_keys[1]]).cpu().numpy()
        # One score per batch row.
        for i in range(y.shape[0]):
            self.running_total += calc_stoi(y[i], y_pred[i])
            self.num += 1

    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        if self.num == 0:
            raise NotComputableError("STOI Metric must have one complete sample before computing")
        return self.running_total / self.num


class ValidationEvents(EventEnum):
    """Custom ignite events around the validation pass."""
    # Fired by run_eval() on the calling engine once the validator has finished running.
    VALIDATION_COMPLETED = "validation_completed"

def register_custom_events(eng: Engine):
    """Register the members of `FrameLoaderEvents` and `ValidationEvents` on `eng`, in that order."""
    for event_enum in (FrameLoaderEvents, ValidationEvents):
        eng.register_events(*event_enum)

def log_trainer_loss(eng: Engine):
    """Print epoch, iteration (modulo the engine's `iteration_ceiling`) and the last step output."""
    state = eng.state
    iteration_in_epoch = state.iteration % state.iteration_ceiling
    print(f"Epoch[{state.epoch}], Iter[{iteration_in_epoch}] Loss: {state.output}")

def log_custom(eng: Engine, **kwargs):
    """Print ``kwargs["template"]`` formatted with the engine's state dict, its epoch and `kwargs`.

    On a name clash `kwargs` wins over ``"epoch"``, which wins over the state dict.
    """
    context = dict(eng.state_dict())
    context["epoch"] = eng.state.epoch
    context.update(kwargs)
    template: str = kwargs["template"]
    print(template.format(**context))

def run_eval(eng: Engine, **kwargs):
    """Run the validator over the validation frames, then fire VALIDATION_COMPLETED on `eng`.

    Keyword Args:
        validator: `Engine` performing the validation pass.
        val_frame_loader: `FrameLoader` providing the validation frames.

    Raises:
        TypeError: If either keyword argument is missing or None.
    """
    validator: Engine = kwargs.get("validator",None)
    val_frame_loader: FrameLoader = kwargs.get("val_frame_loader",None)
    # The messages name this function so a missing argument is traced to the right handler.
    if validator is None:
        raise TypeError("run_eval must be passed the argument `validator` of type `Engine`")
    if val_frame_loader is None:
        raise TypeError("run_eval must be passed the argument `val_frame_loader` of type `FrameLoader`")

    validator.run(val_frame_loader)
    eng.fire_event(ValidationEvents.VALIDATION_COMPLETED)

def log_eval_results(eng: Engine, **kwargs):
    """Print the validator's PESQ, STOI and loss metrics for the current epoch of `eng`.

    Keyword Args:
        validator: `Engine` whose ``state.metrics`` are reported (required).
        prefix: Optional text printed before the report.
        metrics_out: Optional list that receives a copy of the metrics dict.

    Raises:
        TypeError: If `validator` is missing or None.
    """
    prefix = kwargs.get("prefix","")
    validator: Engine = kwargs.get("validator",None)
    if validator is None:
        raise TypeError("log_eval_results must be passed the argument `validator` of type `Engine`")

    metrics = validator.state.metrics
    metrics_out = kwargs.get("metrics_out",None)
    # Identity comparison: `!= None` would defer to the container's own __ne__.
    if metrics_out is not None:
        metrics_out.append(metrics.copy())
    print(f"{prefix}Epoch[{eng.state.epoch}] | PESQ:[{metrics['pesq']:.2f}] | STOI:[{metrics['stoi']:.2f}] | Loss:[{metrics['loss']}]")

def set_engine_custom_keys(eng: Engine):
    """Expose `iteration_ceiling` through the engine's state dict and initialise it to `sys.maxsize`."""
    # Register the key only once so that running the handler again does not add duplicates.
    if "iteration_ceiling" not in eng.state_dict_user_keys:
        eng.state_dict_user_keys.append("iteration_ceiling")
    eng.state.iteration_ceiling = sys.maxsize

def set_iteration_ceiling(eng: Engine, *args):
    """Set ``eng.state.iteration_ceiling``.

    With exactly one extra positional argument that value is used; otherwise the
    engine's current iteration count is stored.
    """
    eng.state.iteration_ceiling = args[0] if len(args) == 1 else eng.state.iteration

## SEGAN

In [7]:
# Default hyper-parameters for gan().
gan_hp = {
    # Samples per frame and hop between frame starts, as passed to FrameLoader / FrameReconstructor.
    "frame_size":16384,
    "frame_shift":8192,
    # RMSprop learning rates of the generator and the discriminator.
    "g_lr":1.0e-4,
    "d_lr":1.0e-4,
    # Files per dataset item; also the row count of every frame.
    "batch_size":128,
    # Passed as max_epochs to the trainer.
    "epochs":80,
    # When true, final weights and out.json are written after the run.
    "save":True,
    # None, or an indexable pair of state-dict paths: [0] generator, [1] discriminator.
    "load":None,
}

# When true, gan() is called as soon as the cell defining it has run.
GAN_RUN_ON_LOAD = False

In [8]:
def gan(hp: dict = gan_hp):
    """Train and validate the SEGAN generator/discriminator pair configured by `hp`.

    A fresh run directory ``saved_models/gan_<timestamp>`` is created. After every
    training epoch the validator runs over the validation split, its metrics are
    appended to ``out["metrics"]`` and the two best checkpoints (by PESQ) are kept.

    Failures and keyboard interrupts during the run are reported on stdout instead
    of being raised, so that the weights reached so far can still be saved.

    Args:
        hp: Hyper-parameter dict with the keys of `gan_hp`.

    Returns:
        A dict with the hyper-parameters, optimizer/criterion names, the collected
        validation metrics and (after a complete run) the total training time in
        nanoseconds; or None when the run failed before both models were built.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    pf_train_totals = [0,0]
    pf_train_num_loops = 0
    pf_eval_total = 0
    pf_eval_num_loops = 0

    out = None
    gen = None
    dcrim = None
    run_dir = None
    try:
        run_dir = f"saved_models/gan_{datetime_string()}"
        # makedirs (unlike mkdir) also creates `saved_models/` itself when it is missing.
        os.makedirs(run_dir)

        from models.segan import Discriminator, Generator

        torch.cuda.empty_cache()
        out = {"hp": hp}
        gen = Generator().to(device=device)
        dcrim = Discriminator().to(device=device)

        if hp["load"] is not None:
            gen.load_state_dict(torch.load(hp["load"][0], weights_only=True))
            dcrim.load_state_dict(torch.load(hp["load"][1], weights_only=True))

        g_optimizer = torch.optim.RMSprop(gen.parameters(), lr=hp["g_lr"])
        d_optimizer = torch.optim.RMSprop(dcrim.parameters(), lr=hp["d_lr"])
        criterion = nn.L1Loss()
        out["optimizer"] = "RMSprop"
        out["criterion"] = "L1loss"

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"),
                                    get_sequential_wav_paths("data/speech_ordered/train"),
                                    batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)
        r = get_reference_batch(train_dataset, hp["frame_size"], lambda x: x.view(hp["batch_size"],1,-1))
        # r is (mixed, clean). Every discriminator input uses the channel order
        # (clean-or-enhanced, noisy), so the reference batch follows it too.
        ref_batch = torch.cat((r[1],r[0]),dim=1).to(device=device)
        z = torch.zeros((hp["batch_size"],1024,8)).to(device=device)
        print(ref_batch.shape)

        def train_step(engine, batch):
            # x: noisy (mixed) frames, y: clean frames.
            x, y = batch[0].to(device=device), batch[1].to(device=device)
            nn.init.normal_(z)

            # --- Discriminator step ---
            dcrim.train()
            dcrim.zero_grad()
            # Real pair: (clean, noisy). It must share the channel order of the fake pair
            # below, otherwise the discriminator can tell the pairs apart from the
            # position of the noisy channel alone.
            real_pair = torch.cat((y, x), dim=1)
            output = dcrim(real_pair, ref_batch)
            clean_loss = torch.mean((output - 1.0) ** 2)
            clean_loss.backward()

            # Fake pair: (enhanced, noisy). The generator output is detached because this
            # step only updates the discriminator; gradients reaching the generator here
            # would be discarded by gen.zero_grad() below anyway.
            gen_out = gen(x, z).detach()
            output = dcrim(torch.cat((gen_out, x), dim=1),ref_batch)
            noisy_loss = torch.mean(output ** 2)
            noisy_loss.backward()

            d_optimizer.step()

            # --- Generator step ---
            gen.train()
            gen.zero_grad()
            gen_out = gen(x, z)
            gen_noise_pair = torch.cat((gen_out, x), dim=1)
            output = dcrim(gen_noise_pair, ref_batch)

            # Adversarial term plus a weighted L1 distance to the clean target.
            g_loss_ = 0.5 * torch.mean((output - 1.0) ** 2)
            l1_dist = torch.abs(torch.add(gen_out, torch.neg(y)))
            g_cond_loss = 100 * torch.mean(l1_dist)
            g_loss = g_loss_ + g_cond_loss
            g_loss.backward()
            g_optimizer.step()

            return g_loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"],
                                        hp["batch_size"], engine=trainer,
                                        output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                                    output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                                     output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        def val_step(engine, batch):
            with torch.no_grad():
                x, y = batch[0].to(device=device), batch[1].to(device=device)
                nn.init.normal_(z)
                y_pred = gen(x, z)
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    # The third element triggers the PESQ/STOI metric updates.
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }
        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        checkpoint_to_save = {"gen": gen, "dcrim": dcrim}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, run_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"],
                                     hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train

    except KeyboardInterrupt:
        # Handled explicitly so that an interrupted run still saves its weights below.
        # (A `return` inside `finally` would achieve that only by silently swallowing
        # every exception, including ones that must propagate.)
        print("interrupted")

    except Exception as e:
        print("failed")
        print(traceback.format_exc())
        print(e)

    # Nothing worth saving or returning unless both models were built.
    if dcrim is None:
        return None

    if hp["save"]:
        torch.save(gen.state_dict(),f"{run_dir}/gen_final.pt")
        torch.save(dcrim.state_dict(),f"{run_dir}/dcrim_final.pt")
        with open(f"{run_dir}/out.json","w") as f:
            json.dump({k: out[k] for k in out.keys() - {'gen', 'dcrim'}},f)
    return out

# Start a training run with the default hyper-parameters as soon as this cell executes.
if GAN_RUN_ON_LOAD:
    gan()


## WaveCRN

In [9]:
# Default hyper-parameters for crn().
crn_hp = {
    # Passed to ConvBSRU as frame_size.
    "frame_size":96,
    # NOTE(review): presumably the hop between frames, as in gan_hp - its use is outside the visible part of crn().
    "frame_shift":40,
    # Adam learning rate.
    "lr":5.0e-7,
    # Files per SortedBatchDataset item.
    "batch_size":128,
    # NOTE(review): "epochs" and "save" are presumably used like their gan_hp counterparts - confirm in crn().
    "epochs":80,
    "save":True,
    # None, or the path of a single state dict loaded into the model.
    "load":None,
}

# NOTE(review): presumably guards a crn() call after the function definition - verify.
CRN_RUN_ON_LOAD = True

In [11]:
def crn(hp: dict = crn_hp):
    """Train the WaveCRN (``ConvBSRU``) model with ignite and return a run summary.

    The module-level perf counters (``pf_train_totals``, ``pf_train_num_loops``,
    ``pf_eval_total``, ``pf_eval_num_loops``) are reset at the start of the call
    and updated by the train/validation steps.

    Args:
        hp: Hyper-parameter dict. The keys read here are ``frame_size``,
            ``frame_shift``, ``lr``, ``batch_size``, ``epochs``, ``save`` and
            ``load`` (a state-dict path to load before training, or ``None``).

    Returns:
        The ``out`` dict with ``hp``, ``model``, ``optimizer``, ``criterion`` and
        ``metrics``; after a complete run it also holds ``total_time`` plus the
        per-iteration means ``fwd``/``bck``/``eval``, all in nanoseconds.
        ``None`` is returned when the model was never constructed.

    Note:
        Errors raised during setup or training are printed rather than raised,
        and because the ``finally`` clause returns, an in-flight exception does
        not propagate to the caller.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    # Reset the shared counters so the averages computed below describe this run only.
    pf_train_totals = [0,0]                                                     ###
    pf_train_num_loops = 0                                                      ###
    pf_eval_total = 0
    pf_eval_num_loops = 0
    try:
        datestring_at_start = datetime_string()
        save_dir = f"saved_models/crn_{datestring_at_start}"
        # makedirs also creates a missing "saved_models" parent, which os.mkdir would not.
        os.makedirs(save_dir, exist_ok=True)

        from models.wavecrn import ConvBSRU

        torch.cuda.empty_cache()
        out = {"hp": hp}
        model = ConvBSRU(frame_size=hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
        if hp["load"] is not None:
            model.load_state_dict(torch.load(hp["load"], weights_only=True))
        out["model"] = model

        optimizer = torch.optim.Adam(model.parameters(),lr=hp["lr"])
        criterion = nn.L1Loss()
        out["optimizer"] = "Adam"
        out["criterion"] = "L1Loss"

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"),
                                      get_sequential_wav_paths("data/speech_ordered/train"),
                                      batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)

        def train_step(engine, batch):
            """Run one optimisation step and return the scalar loss."""
            global pf_train_totals, pf_train_num_loops
            # NOTE(review): CUDA work is launched asynchronously, so on GPU these
            # wall-clock timers may under-report unless torch.cuda.synchronize()
            # is called before each reading - confirm whether that matters here.
            pf_train_forward = time.perf_counter_ns()                               ###
            model.train()
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)       ###
            pf_train_back = time.perf_counter_ns()                                  ###
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)          ###
            pf_train_num_loops += 1                                                 ###
            return loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"], batch_size=hp["batch_size"], engine=trainer, output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        # One reconstructor per stream: model output and clean reference.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        def val_step(engine, batch):
            """Run one no-grad forward pass and feed both frame reconstructors.

            Returns ``(y_pred, y)``; when ``batch[2]`` is truthy a third element
            carrying the stitched audio of both streams is appended.
            """
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()                                ###
            model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)         ###
                pf_eval_num_loops += 1                                              ###
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            # Loss only needs (y_pred, y); drop the optional stitched-audio dict.
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }
        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        # Checkpoints are scored by the validator's 'pesq' metric; n_saved=2 limits how many are kept.
        checkpoint_to_save = {"model":model}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, save_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"],
                                     hp["batch_size"],output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)                        ###
        out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)                        ###
        out["eval"]=pf_eval_total / float(pf_eval_num_loops)                          ###

    except Exception as e:
        # format_exc() returns the traceback text; print_exc() returns None, so
        # wrapping it in print() used to emit a stray "None".
        print(traceback.format_exc())
        print(e)
    finally:
        # "model" is only bound once construction succeeded, so this also
        # guarantees that `out`, `save_dir` and `datestring_at_start` exist.
        if "model" in locals():
            if hp["save"]:
                torch.save(model.state_dict(),f"{save_dir}/final.pt")
                # The model object itself is not JSON-serialisable; leave it out.
                json_dict = {k: out[k] for k in out.keys() - {'model'}}
                with open(f"{save_dir}/out.json","w") as f:
                    json.dump(json_dict,f)
            return out
        else:
            return None

# Start WaveCRN training with the default crn_hp as soon as this cell executes,
# but only when the CRN_RUN_ON_LOAD flag is truthy. The returned summary is discarded.
if CRN_RUN_ON_LOAD:
    crn()

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  from tqdm.autonotebook import tqdm


Training Epoch[1/?]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

  out = torch.tensor(self.audio[:,0:self.end - self.frame_shift])


Epoch[1] | PESQ:[1.26] | STOI:[0.80] | Loss:[0.017353107467738445]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[2] | PESQ:[1.33] | STOI:[0.82] | Loss:[0.015892756131358026]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[3] | PESQ:[1.36] | STOI:[0.83] | Loss:[0.015625199853163883]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[4] | PESQ:[1.31] | STOI:[0.81] | Loss:[0.014960630982347506]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[5] | PESQ:[1.38] | STOI:[0.83] | Loss:[0.014529285043692698]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[6] | PESQ:[1.42] | STOI:[0.83] | Loss:[0.014353434244791666]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[7] | PESQ:[1.41] | STOI:[0.84] | Loss:[0.014134841991748965]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[8] | PESQ:[1.43] | STOI:[0.84] | Loss:[0.013870264953219285]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[9] | PESQ:[1.46] | STOI:[0.84] | Loss:[0.013761930263168473]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[10] | PESQ:[1.45] | STOI:[0.84] | Loss:[0.01347442150160577]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[11] | PESQ:[1.44] | STOI:[0.84] | Loss:[0.013398528690143635]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[12] | PESQ:[1.48] | STOI:[0.84] | Loss:[0.01344897591127561]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[13] | PESQ:[1.48] | STOI:[0.84] | Loss:[0.01333168779948282]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[14] | PESQ:[1.49] | STOI:[0.85] | Loss:[0.013202272883210105]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[15] | PESQ:[1.52] | STOI:[0.85] | Loss:[0.013310089382557419]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[16] | PESQ:[1.48] | STOI:[0.84] | Loss:[0.013047203958307753]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[17] | PESQ:[1.49] | STOI:[0.85] | Loss:[0.013050251719056974]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[18] | PESQ:[1.50] | STOI:[0.85] | Loss:[0.013048447330491627]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[19] | PESQ:[1.49] | STOI:[0.85] | Loss:[0.012912628547839046]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[20] | PESQ:[1.51] | STOI:[0.85] | Loss:[0.012932790877029686]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[21] | PESQ:[1.53] | STOI:[0.85] | Loss:[0.012861280559946996]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[22] | PESQ:[1.52] | STOI:[0.85] | Loss:[0.012784434163388618]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[23] | PESQ:[1.56] | STOI:[0.85] | Loss:[0.012901880730189915]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[24] | PESQ:[1.54] | STOI:[0.85] | Loss:[0.01280889476312535]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[25] | PESQ:[1.57] | STOI:[0.85] | Loss:[0.01285472252112642]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[26] | PESQ:[1.56] | STOI:[0.85] | Loss:[0.012842171742434542]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[27] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012803904144118661]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[28] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.01266105766771038]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[29] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012729655838690769]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[30] | PESQ:[1.53] | STOI:[0.85] | Loss:[0.012664519410013361]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[31] | PESQ:[1.53] | STOI:[0.85] | Loss:[0.012567417895295398]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[32] | PESQ:[1.53] | STOI:[0.85] | Loss:[0.012515421812143325]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[33] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012560797045480458]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[34] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.01253682751351784]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[35] | PESQ:[1.56] | STOI:[0.86] | Loss:[0.012487062648188265]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[36] | PESQ:[1.56] | STOI:[0.85] | Loss:[0.012439349132833696]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[37] | PESQ:[1.54] | STOI:[0.85] | Loss:[0.012426367242316869]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[38] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012530198098567306]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[39] | PESQ:[1.54] | STOI:[0.85] | Loss:[0.012380702221892102]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[40] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012383120445175092]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[41] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012466801819938284]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[42] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.01239269198420251]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[43] | PESQ:[1.56] | STOI:[0.86] | Loss:[0.01234490852024058]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[44] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.01235281128534957]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[45] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012336417615820734]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[46] | PESQ:[1.56] | STOI:[0.85] | Loss:[0.012342872873014167]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[47] | PESQ:[1.60] | STOI:[0.86] | Loss:[0.012402083655382432]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[48] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012291440661288766]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[49] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012329195778991544]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[50] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012268401874060805]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[51] | PESQ:[1.54] | STOI:[0.85] | Loss:[0.012269396857312424]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[52] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012221165151254355]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[53] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012239050581899323]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[54] | PESQ:[1.57] | STOI:[0.85] | Loss:[0.012215446495722091]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[55] | PESQ:[1.54] | STOI:[0.85] | Loss:[0.012233958608788094]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[56] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012247782737638432]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[57] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.01240974945173996]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[58] | PESQ:[1.57] | STOI:[0.85] | Loss:[0.012183210180387694]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[59] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012193512610985885]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[60] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012201416803617474]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[61] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012141842002986424]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[62] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012149686239668847]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[63] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012154008778099126]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[64] | PESQ:[1.60] | STOI:[0.86] | Loss:[0.012238409624252298]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[65] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012173721437671028]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[66] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012122166459002099]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[67] | PESQ:[1.55] | STOI:[0.85] | Loss:[0.012126685995551063]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[68] | PESQ:[1.61] | STOI:[0.86] | Loss:[0.01211211812742656]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[69] | PESQ:[1.57] | STOI:[0.86] | Loss:[0.012096188402742451]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[70] | PESQ:[1.56] | STOI:[0.85] | Loss:[0.01209876222598875]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[71] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012121321365623392]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[72] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012086108665669731]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[73] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.01207609744968177]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[74] | PESQ:[1.58] | STOI:[0.85] | Loss:[0.01206534534946484]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[75] | PESQ:[1.56] | STOI:[0.86] | Loss:[0.012063608054461654]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[76] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012033708593622095]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[77] | PESQ:[1.59] | STOI:[0.86] | Loss:[0.012048293592017787]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[78] | PESQ:[1.56] | STOI:[0.86] | Loss:[0.012072423006511512]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[79] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012033027665342292]


Training Epoch[1/195756]   0%|           [00:00<?]

Validation[1/?]   0%|           [00:00<?]

Epoch[80] | PESQ:[1.58] | STOI:[0.86] | Loss:[0.012054206390623538]


In [None]:
# if 'crn_model' in locals(): torch.save(crn_model.state_dict(),f"saved_models/crn_{datetime_string()}.pt")

## RHR-Net

In [None]:
# Default hyper-parameters consumed by rnn(); every key below is read there.
rnn_hp = dict(
    frame_size=1024,    # passed to the frame loader/reconstructors
    frame_shift=256,    # hop between consecutive frames
    lr=1.0e-4,          # learning rate handed to the optimizer
    batch_size=128,
    epochs=30,          # max_epochs for the ignite trainer
    save=True,          # write final weights + out.json when the run ends
    load=None,          # optional state-dict path to load before training
)

# When True, rnn() is invoked as soon as its defining cell runs.
RNN_RUN_ON_LOAD = False

In [None]:
def rnn(hp: dict = rnn_hp):
    """Train the RHR-Net model with ignite and return a run summary.

    The module-level perf counters (``pf_train_totals``, ``pf_train_num_loops``,
    ``pf_eval_total``, ``pf_eval_num_loops``) are reset at the start of the call
    and updated by the train/validation steps.

    Args:
        hp: Hyper-parameter dict. The keys read here are ``frame_size``,
            ``frame_shift``, ``lr``, ``batch_size``, ``epochs``, ``save`` and
            ``load`` (a state-dict path to load before training, or ``None``).
            The network architecture itself is configured from
            ``models/rhrnetdir/rhrnet_hyperparameters.yaml``.

    Returns:
        The ``out`` dict with ``hp``, ``datetime_str``, ``model``, ``optimizer``,
        ``criterion`` and ``metrics``; after a complete run it also holds
        ``total_time`` plus the per-iteration means ``fwd``/``bck``/``eval``,
        all in nanoseconds. ``None`` is returned when the model was never built.

    Note:
        Errors raised during setup or training are printed rather than raised,
        and because the ``finally`` clause returns, an in-flight exception does
        not propagate to the caller.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    # Reset the shared counters so the averages computed below describe this run only.
    pf_train_totals = [0,0]                                                     ###
    pf_train_num_loops = 0                                                      ###
    pf_eval_total = 0
    pf_eval_num_loops = 0
    try:
        datestring_at_start = datetime_string()
        save_dir = f"saved_models/rnn_{datestring_at_start}"
        # makedirs also creates a missing "saved_models" parent, which os.mkdir would not.
        os.makedirs(save_dir, exist_ok=True)

        import yaml
        from models.rhrnetdir.Arg_Parser import Recursive_Parse
        from models.rhrnet import RHRNet
        # Named `model_hp` so it no longer shadows the module-level `rnn_hp`
        # training dict; `with` closes the file handle the original left open.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; that
        # is acceptable only because this is a project-local config file.
        with open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8') as hp_file:
            model_hp = Recursive_Parse(yaml.load(hp_file, Loader=yaml.Loader))
        torch.cuda.empty_cache()

        out = {"hp": hp, "datetime_str": datestring_at_start}
        model = RHRNet(model_hp).to(device=device)
        if hp["load"] is not None:
            model.load_state_dict(torch.load(hp["load"], weights_only=True))
        out["model"] = model

        optimizer = torch.optim.RMSprop(model.parameters(),lr=hp["lr"])
        # `criterion` was used below without being bound in this function, so it
        # only worked if a notebook-level global happened to exist. Keep using
        # that global when present; otherwise fall back to the L1 loss that
        # crn() uses. TODO(review): pin the intended loss explicitly.
        criterion = globals().get("criterion")
        if criterion is None:
            criterion = nn.L1Loss()
        # Previously written to out["model"], which clobbered the model handle
        # and dropped the optimizer name from out.json.
        out["optimizer"] = "RMSProp"
        out["criterion"] = type(criterion).__name__

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)
        print(f"val dataset:{len(val_dataset)}")

        def train_step(engine, batch):
            """Run one optimisation step and return the scalar loss."""
            global pf_train_totals, pf_train_num_loops
            # NOTE(review): CUDA work is launched asynchronously, so on GPU these
            # wall-clock timers may under-report unless torch.cuda.synchronize()
            # is called before each reading - confirm whether that matters here.
            pf_train_forward = time.perf_counter_ns()                               ###
            model.train()
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)       ###
            pf_train_back = time.perf_counter_ns()                                  ###
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)          ###
            pf_train_num_loops += 1                                                 ###
            return loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"], batch_size=hp["batch_size"], engine=trainer, output_transform=lambda x: x.view((hp["batch_size"],-1)))

        # One reconstructor per stream: model output and clean reference.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],-1)))
        def val_step(engine, batch):
            """Run one no-grad forward pass and feed both frame reconstructors.

            Returns ``(y_pred, y)``; when ``batch[2]`` is truthy a third element
            carrying the stitched audio of both streams is appended.
            """
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()                                ###
            model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)         ###
                pf_eval_num_loops += 1                                              ###
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            # Loss only needs (y_pred, y); drop the optional stitched-audio dict.
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }

        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        # Checkpoints are scored by the validator's 'pesq' metric; n_saved=2 limits how many are kept.
        checkpoint_to_save = {"model":model}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, save_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        # Apply the same batch reshaping as the training loader so validation
        # inputs match what the model was trained on (mirrors gan()/crn()).
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                     output_transform=lambda x: x.view((hp["batch_size"],-1)))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)                        ###
        out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)                        ###
        out["eval"]=pf_eval_total / float(pf_eval_num_loops)                          ###

    except Exception as e:
        # Show where the failure happened, not just its message.
        print(traceback.format_exc())
        print(e)
    finally:
        # "model" is only bound once construction succeeded, so this also
        # guarantees that `out`, `save_dir` and `datestring_at_start` exist.
        if "model" in locals():
            if hp["save"]:
                torch.save(model.state_dict(),f"{save_dir}/final.pt")
                with open(f"{save_dir}/out.json","w") as f:
                    # The model object itself is not JSON-serialisable; leave it out.
                    json.dump({k: out[k] for k in out.keys() - {'model'}},f)
            return out
        else:
            return None

# Start RHR-Net training with the default rnn_hp as soon as this cell executes,
# but only when the RNN_RUN_ON_LOAD flag is truthy. The returned summary is discarded.
if RNN_RUN_ON_LOAD:
    rnn()

In [None]:
# Mean per-iteration timings of the most recent crn()/rnn() run, converted by ns_to_sec.
# The eval mean is divided by the real validation iteration counter rather than
# the old `pf_train_num_loops/4` estimate, which only held for one particular
# train/validation ratio.
a=ns_to_sec(pf_train_totals[0] / float(pf_train_num_loops))                         ###
b=ns_to_sec(pf_train_totals[1] / float(pf_train_num_loops))                         ###
c=ns_to_sec(pf_eval_total / float(pf_eval_num_loops))                               ###
print(f"train forward:{a:.20f} | backprop:{b:.20f} | eval forward:{c:.20f}")

In [None]:
# if 'rnn_model' in locals(): torch.save(rnn_model.state_dict(),f"saved_models/rnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

## Wave-U-Net

In [None]:
# Wave-U-Net framing and optimiser settings.
CNN_FRAME_SIZE = 16153
CNN_OUT_FRAME_SIZE = 16009
# Frame shifts are sample counts, so use floor division: true division would
# yield the float 4038.25 here.
CNN_FRAME_SHIFT = CNN_FRAME_SIZE // 4
CNN_LR = 1.0e-4

# Stage toggles for the Wave-U-Net section.
CNN_PREP = True
CNN_TRAIN = False
CNN_LOAD= False
CNN_SAVE = False

In [None]:
# if 'cnn_model' in locals(): torch.save(cnn_model.state_dict(),f"saved_models/cnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

# Running Models 

## Model Running Functions 

In [None]:
class FakeDataset(Dataset):
    """In-memory dataset that pairs two index-aligned lists of waveforms.

    ``__getitem__`` returns a ``(left, right)`` tuple of tensors, each a copy of
    the stored waveform with a new leading dimension of size 1.
    """

    def __init__(self, l_waveforms: list, r_waveforms:list):
        self.l_waveforms = l_waveforms
        self.r_waveforms = r_waveforms
        super().__init__()

    def __len__(self):
        # The length is taken from the left list only.
        return len(self.l_waveforms)

    @staticmethod
    def _copy_with_leading_dim(waveform):
        """Return an independent tensor copy of ``waveform`` with a size-1 dim prepended."""
        if isinstance(waveform, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning on every
            # call; detach().clone() is the documented equivalent.
            copied = waveform.detach().clone()
        else:
            copied = torch.tensor(waveform)
        # The in-place unsqueeze is safe because `copied` is a fresh tensor.
        return copied.unsqueeze_(0)

    def __getitem__(self,idx):
        return (self._copy_with_leading_dim(self.l_waveforms[idx]),
                self._copy_with_leading_dim(self.r_waveforms[idx]))

def evaluate_e2e_one_sample(model_dict: dict):
    chosen_sample = ntpath.basename(random.choice(glob.glob("data/speech_ordered/train/*.wav")))
    clean_sample,_ = torchaudio.load("data/speech_ordered/train/" + chosen_sample)
    mixed_sample,_ = torchaudio.load("data/mixed/train/" + chosen_sample)
    ds = FakeDataset([mixed_sample],[clean_sample])
    dl = DataLoader(ds)
    out = {}
    for name in model_dict.keys():
        out[name] = {}
        out[name]["model"] = model_dict[name]
        match name:
            case "cnn":
                out[name]["loader"] = FrameLoader(dl,CNN_FRAME_SIZE,CNN_FRAME_SHIFT, batch_size=1)
                out[name]["processed_constructor"] = FrameReconstructor(CNN_OUT_FRAME_SIZE,CNN_FRAME_SHIFT, batch_size=1)
                out[name]["clean_constructor"] = FrameReconstructor(CNN_OUT_FRAME_SIZE,CNN_FRAME_SHIFT, batch_size=1)
            case "rnn":
                out[name]["loader"] = FrameLoader(dl,RNN_FRAME_SIZE,RNN_FRAME_SHIFT, batch_size=1)
                out[name]["processed_constructor"] = FrameReconstructor(RNN_FRAME_SIZE,RNN_FRAME_SHIFT, batch_size=1)
                out[name]["clean_constructor"] = FrameReconstructor(RNN_FRAME_SIZE,RNN_FRAME_SHIFT, batch_size=1)
            case "crn":
                out[name]["loader"] = FrameLoader(dl,CRN_FRAME_SIZE,CRN_FRAME_SHIFT, batch_size=1, output_transform=lambda x: x.reshape(1,1,-1))
                out[name]["processed_constructor"] = FrameReconstructor(CRN_FRAME_SIZE,CRN_FRAME_SHIFT, batch_size=1)
                out[name]["clean_constructor"] = FrameReconstructor(CRN_FRAME_SIZE,CRN_FRAME_SHIFT, batch_size=1)
            # case "gan":
            #     out[name]["loader"] = FrameLoader(dl,GAN_FRAME_SIZE,GAN_FRAME_SHIFT)
            case _:
                pass
        out[name]["perf"] = {"e2e_time":0, "avg_forward":0}
    # print("shape__:" + str(clean_sample.shape))
    out["num_frames"] = clean_sample.shape[1]
    
    for name in model_dict.keys():
        # print(name)
        # print(out[name],flush=True)
        at_end = False
        model = out[name]["model"]
        data = iter(out[name]["loader"])
        pf_eval_total = 0
        n_loops = 0
        pf_eval_e2e = time.perf_counter_ns()
        while not at_end:
            pf_eval_forward = time.perf_counter_ns()
            model.eval()
            with torch.no_grad():
                x, y, at_end = next(data)
                x = x.to(device=device)
                y = y.to(device=device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                n_loops += 1
                try:
                    out[name]["processed_constructor"].add_frame(y_pred)
                    out[name]["clean_constructor"].add_frame(y)
                except RuntimeError as e:
                    print(out[name]["processed_constructor"].audio.shape)
                    raise RuntimeError(e.args)
                

                if at_end:    #   Frame fully constructed
                    y_pred_stitch = out[name]["processed_constructor"].get_current_audio()
                    y_stitch = out[name]["clean_constructor"].get_current_audio()
                    print(y_pred_stitch.shape)
                    print(y_stitch.shape)
                    display.display(Audio(y_pred_stitch[0][0],rate=16000))
                    display.display(Audio(y_stitch[0][0],rate=16000))
                    
        
        pf_eval_e2e = time.perf_counter_ns() - pf_eval_e2e
        out[name]["perf"]["e2e_time"] = pf_eval_e2e
        out[name]["perf"]["avg_forward"] = pf_eval_total / float(n_loops)
    
    return out

In [None]:
crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
crn_model.load_state_dict(torch.load("saved_models/crn.pt", weights_only=True))
models = {  "crn": crn_model }#, "rnn": rnn_model}#, "cnn": cnn_model,}
out = evaluate_e2e_one_sample(models)

print("Duration of audio: " + str(out["num_frames"] / 16000.0))
for name in models.keys():
    model_dict = out[name]
    print(f"{name.upper()} -- e2e:{ns_to_sec(model_dict["perf"]["e2e_time"])}, average forward:{ns_to_sec(model_dict["perf"]["avg_forward"])}")