In [None]:
#!conda env update --file environment.yml --prune

# Initialisation

In [2]:
import os, random, glob, logging, ntpath, math, time, sys, datetime, json, traceback
from typing import Callable
from IPython import display
from IPython.display import Audio
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# pd.options.display.max_seq_items = 2000
pd.set_option("display.max_colwidth",None)

# Install the default root handler so log records are actually emitted.
logging.basicConfig()
# "dbg": general-purpose debug logger used by the helpers in this notebook.
logger=logging.getLogger("dbg")
logger.setLevel(logging.DEBUG)
# disable(NOTSET) lifts any process-wide logging.disable() threshold that a
# previous run of the commented-out line below may have left in place.
logging.disable(logging.NOTSET)
# "perf": separate logger, apparently intended for timing output.
perf_logger=logging.getLogger("perf")
perf_logger.setLevel(logging.DEBUG)
# logging.disable(logging.DEBUG)

import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
# Prefer the first CUDA device and fall back to the CPU when CUDA is unavailable.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)
# get_device_name() only accepts CUDA devices; querying it with "cpu" raises,
# which would abort this cell on a machine without a GPU.
if device.startswith("cuda"):
    logger.debug(torch.cuda.get_device_name(device))
logger.debug(torchaudio.list_audio_backends())

from ignite.engine import Engine, Events, EventEnum
from ignite.metrics import Loss, Metric, RunningAverage
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.exceptions import NotComputableError
from ignite.handlers.tqdm_logger import ProgressBar
from ignite.handlers import Checkpoint, DiskSaver

DEBUG:dbg:cuda:0
DEBUG:dbg:2.6.0+cu124
DEBUG:dbg:NVIDIA RTX A4000
DEBUG:dbg:['ffmpeg']


In [3]:
# Whether the training DataLoader shuffles its batches.
SHUFFLE = True
# Sampling rate (Hz) passed to the quality metrics.
SAMPLE_RATE = 16_000
# Length, in samples, of the reconstruction buffer: 40 seconds of audio.
MAXIMUM_SAMPLE_NUM_OF_FRAMES = SAMPLE_RATE * 40
# Nanoseconds per millisecond.
MS_TO_NS = 1e6

## Utility

In [4]:
from IPython.core.magic import register_cell_magic
from IPython import get_ipython

@register_cell_magic
def skip(line, cell):
    """Cell magic ``%%skip``: ignore the cell body entirely."""
    return None

@register_cell_magic
def skip_if(line, cell):
    """Cell magic ``%%skip_if <expr>``: run the cell only when ``<expr>`` is falsy.

    The magic line is evaluated with ``eval``; it is written by the notebook
    author, so only ever use it with trusted expressions.
    """
    if not eval(line):
        get_ipython().run_cell(cell)

In [5]:
import pystoi
import pesq
from utils.ssnr import snrseg

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    """Mix ``noise`` into ``speech`` at the requested signal-to-noise ratio.

    Args:
        speech: Clean signal. Integer tensors are promoted to float64 first.
        noise: Noise signal. Integer tensors are promoted to float64 first.
        snr: Signal-to-noise ratio, either as a tensor or as a plain number
            (which is wrapped into a one-element tensor).

    Returns:
        The result of ``torchaudio.functional.add_noise`` cast to float32.
    """
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        speech = speech.to(torch.float64, non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        noise = noise.to(torch.float64, non_blocking=True)
    if not isinstance(snr, torch.Tensor):
        # Build the SNR tensor on the same device as the audio; a CPU tensor
        # here would cause a device mismatch when speech/noise live on CUDA.
        snr = torch.tensor([snr], device=speech.device)
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    return audioF.add_noise(speech, noise, snr).to(dtype=torch.float)


def calc_snrseg(speech: np.ndarray, processed: np.ndarray) -> float:
    """Score ``processed`` against ``speech`` with ``snrseg`` at ``SAMPLE_RATE``.

    A ``numpy.float64`` result is unwrapped to a built-in float; any other
    return value is passed through unchanged.
    """
    score = snrseg(clean_speech=speech, processed_speech=processed, fs=SAMPLE_RATE)
    return score.item() if isinstance(score, np.float64) else score

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    """Return ``pesq.pesq`` for ``processed`` (degraded) against ``speech`` (reference) at ``SAMPLE_RATE``."""
    score = pesq.pesq(ref=speech, deg=processed, fs=SAMPLE_RATE)
    return score

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    """Score ``processed`` against ``speech`` with ``pystoi.stoi`` at ``SAMPLE_RATE``.

    A ``numpy.float64`` result is unwrapped to a built-in float; any other
    return value is passed through unchanged.
    """
    score = pystoi.stoi(x=speech, y=processed, fs_sig=SAMPLE_RATE)
    return score.item() if isinstance(score, np.float64) else score

def ns_to_sec(ns: int) -> float:
    """Convert a duration in nanoseconds to seconds."""
    nanoseconds_per_second = 1e9
    return ns / nanoseconds_per_second

def datetime_string() -> str:
    """Return the current local time formatted as ``DD-MM-YYYY--HH-MM-SS``."""
    return f"{datetime.datetime.now():%d-%m-%Y--%H-%M-%S}"

def plot_waveform(waveform, sample_rate=SAMPLE_RATE):
    """Plot every channel of a 2-D waveform tensor on its own subplot.

    Args:
        waveform: Tensor unpacked as ``(num_channels, num_frames)``.
        sample_rate: Divides the sample index to build the x-axis values.
    """
    samples = waveform.numpy()
    num_channels, num_frames = samples.shape
    seconds = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    # With a single row plt.subplots returns a bare Axes rather than an array.
    channel_axes = [axes] if num_channels == 1 else axes
    label_channels = num_channels > 1
    for number, (ax, channel) in enumerate(zip(channel_axes, samples), start=1):
        ax.plot(seconds, channel, linewidth=1)
        ax.grid(True)
        if label_channels:
            ax.set_ylabel(f"Channel {number}")
    figure.suptitle("waveform")

def write_fstring_file(model_name: str, format_string: str, **args):
    """Render ``format_string`` with ``args`` and write it to a timestamped text file.

    The file is created as ``saved_models/<model_name>_<timestamp>.txt``.

    Args:
        model_name: Prefix of the output file name.
        format_string: ``str.format`` template.
        **args: Values substituted into the template.
    """
    # The file must be opened for writing: the default mode is read-only, which
    # made every call fail (missing file, or an unsupported write).
    with open(f"saved_models/{model_name}_{datetime_string()}.txt", "w") as f:
        f.write(format_string.format(**args))

def standardize_batch(batch: torch.Tensor, coerce_func = lambda x: x) -> torch.Tensor:
    """Squeeze ``batch`` and return it as a 2-D tensor where possible.

    After dropping all size-1 dimensions, a 1-D result gets a leading
    dimension of size 1, a 2-D result is returned unchanged, and anything else
    is handed to ``coerce_func``.
    """
    squeezed = batch.squeeze()
    rank = squeezed.dim()
    if rank == 1:
        return squeezed.unsqueeze(dim=0)
    if rank == 2:
        return squeezed
    return coerce_func(squeezed)
    
def calc_windowing(data_len: int, frame_size: int, frame_shift: int):
    """Work out how ``data_len`` samples divide into overlapping frames.

    Returns:
        ``(num_frames, spare, to_pad)``: the number of whole frames, the
        samples left over after the last whole shift, and ``frame_shift``
        minus that remainder. ``(0, 0, 0)`` is returned whenever ``data_len``
        does not exceed ``frame_size``.
    """
    overhang = data_len - frame_size
    if overhang < 1:
        return 0, 0, 0
    # One frame at position 0, plus one for every whole shift that still fits.
    extra_frames, spare = divmod(overhang, frame_shift)
    return int(1 + extra_frames), int(spare), int(frame_shift - spare)




## Dataset

In [6]:
from torch.utils.data import Dataset, DataLoader

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedBatchDataset(Dataset):
    """Dataset whose items are whole batches of paired (mixed, clean) waveforms.

    ``mixed`` and ``clean`` are parallel lists of audio file paths. Item ``idx``
    covers files ``idx*batch_size`` .. ``(idx+1)*batch_size - 1``; files beyond
    the last complete batch are never served.
    """

    def __init__(self, mixed: list, clean: list, batch_size: int):
        """
        Args:
            mixed: Paths of the noisy recordings.
            clean: Paths of the matching clean recordings, in the same order.
            batch_size: Number of consecutive files grouped into one item.
        """
        self.mixed = mixed
        self.clean = clean
        self.batch_size = batch_size

    def __len__(self):
        """Number of complete batches."""
        return len(self.mixed) // self.batch_size

    def __getitem__(self, idx):
        """Load batch ``idx`` and return ``(mixed, clean)`` tensors, one row per file.

        Raises:
            IndexError: if ``idx`` does not address a complete batch.
        """
        batch_count = len(self)
        if idx < 0:
            idx += batch_count
        # Validate the batch index itself. The previous check compared
        # ``idx + i`` with the number of files, which never matched the flat
        # index actually used for loading.
        if not 0 <= idx < batch_count:
            raise IndexError(f"batch index out of range for {batch_count} batches")

        first = idx * self.batch_size
        mixeds = []
        cleans = []
        max_len = 0
        for offset in range(self.batch_size):
            # torchaudio.load returns (waveform, sample_rate); keep channel 0 only.
            mixed_wave, _ = torchaudio.load(self.mixed[first + offset])
            clean_wave, _ = torchaudio.load(self.clean[first + offset])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert mixed_wave.shape[0] == clean_wave.shape[0], "mixed/clean pair differ in length"
            if offset == 0:
                # The first file of the batch fixes the row length.
                # NOTE(review): this presumably relies on the files being ordered
                # so that the first one is the longest - confirm; a longer later
                # file produces a negative pad amount here.
                max_len = mixed_wave.shape[0]
            else:
                mixed_wave = torch.nn.functional.pad(mixed_wave, (0, max_len - mixed_wave.shape[0]), value=0.0)
                clean_wave = torch.nn.functional.pad(clean_wave, (0, max_len - clean_wave.shape[0]), value=0.0)
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)
        # Stack directly instead of round-tripping through a NumPy array.
        return torch.stack(mixeds), torch.stack(cleans)

    def split(self, val_pct, seed=None):
        """Split into ``(train, validation)`` datasets along batch boundaries.

        Args:
            val_pct: Fraction of the batches (rounded down) moved to validation.
            seed: Seed for the random choice of validation batches.

        Returns:
            Two ``SortedBatchDataset`` objects with the same ``batch_size``;
            batches keep their original relative order.
        """
        rnd = random.Random(seed)
        this_len = len(self)
        val_batches = math.floor(this_len * val_pct)
        # A set gives O(1) membership tests in the loop below.
        val_indices = set(rnd.sample(range(this_len), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(this_len):
            if i in val_indices:
                target_mixed, target_clean = val_mixed, val_clean
            else:
                target_mixed, target_clean = train_mixed, train_clean
            for flat in range(i * self.batch_size, (i + 1) * self.batch_size):
                target_mixed.append(self.mixed[flat])
                target_clean.append(self.clean[flat])

        return (SortedBatchDataset(train_mixed, train_clean, self.batch_size),
                SortedBatchDataset(val_mixed, val_clean, self.batch_size))

# Custom Ignite event fired by FrameLoader on its attached engine when the
# frame being produced reaches the end of the current batch.
class FrameLoaderEvents(EventEnum):
    END_OF_BATCH = "end_of_batch"

class FrameLoader():
    '''Iterates over a dataloader of (mixed, clean) batches and yields them frame by frame.

    Each step returns ``(mix, clean, has_batch_ended)``: one window of
    ``frame_size`` samples from each batch, advanced by ``frame_shift`` per
    step. The last window of a batch is zero-padded when it would run past the
    end of the audio, and ``has_batch_ended`` is True for that window.
    '''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, batch_size: int, engine: Engine | None = None, output_transform = lambda x: x):
        '''
        Args:
            dl: Source of ``(mixed, clean)`` batches.
            frame_size: Samples per frame.
            frame_shift: Samples between the starts of consecutive frames.
            batch_size: Number of rows in every frame that is produced.
            engine: Optional engine on which ``FrameLoaderEvents.END_OF_BATCH`` is fired.
            output_transform: Applied to each frame before it is returned.
        '''
        self.dl = dl
        self.dl_iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_size = batch_size
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        # Starting "at the end" makes the first __next__ call load a batch.
        self.at_end = True
        self.engine = engine
        self.output_transform = output_transform

    def __iter__(self):
        self.dl_iter = iter(self.dl)
        return self

    def __next__(self):
        if self.at_end:
            # Previous batch exhausted: fetch the next one. StopIteration from
            # the underlying iterator ends this iteration as well.
            batches: tuple[torch.Tensor, torch.Tensor] = next(self.dl_iter)
            self.batch_mixed = batches[0].squeeze_()
            if len(self.batch_mixed.shape) == 1:
                self.batch_mixed.unsqueeze_(0)
            self.batch_clean = batches[1].squeeze_()
            if len(self.batch_clean.shape) == 1:
                self.batch_clean.unsqueeze_(0)
            self.frame_position = 0
            self.at_end = False

        frame_end = self.frame_position + self.frame_size
        frames = []
        for batch_i, batch in enumerate([self.batch_mixed, self.batch_clean]):
            length = batch.shape[-1]
            if frame_end >= length:
                # Fire once per frame (for the mixed batch only), not once per tensor.
                if self.engine is not None and batch_i == 0:
                    self.engine.fire_event(FrameLoaderEvents.END_OF_BATCH)
                self.at_end = True
            # Copy whatever audio is available; the rest of the frame stays
            # zero, which pads the final window of a batch.
            stop = min(frame_end, length)
            frame = torch.zeros((self.batch_size, self.frame_size), dtype=torch.float32)
            frame[:, 0:stop - self.frame_position] = batch[:, self.frame_position:stop]
            frames.append(self.output_transform(frame))

        self.frame_position += self.frame_shift

        # Sanity check on the transformed frame width. Failures here now
        # propagate with their original type and traceback instead of being
        # re-raised as a bare Exception.
        if frames[0].shape[-1] != self.frame_size:
            logger.debug(frames[0].shape[-1])
            logger.debug("FrameLoader issue")

        return frames[0], frames[1], self.at_end

class FrameReconstructor():
    '''Rebuilds a batch of audio by appending (batches of) frames to a fixed-size buffer.

    `add_frame()` writes a frame while skipping the part that overlaps the
    audio already stored; `add_presliced()` appends data verbatim;
    `get_current_audio()` returns everything written so far.
    '''
    def __init__(self, frame_size: int, frame_shift: int, batch_size: int, output_transform = lambda x: x):
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.output_transform = output_transform
        # One row per batch element, MAXIMUM_SAMPLE_NUM_OF_FRAMES samples long.
        self.audio: torch.Tensor = torch.zeros((batch_size, MAXIMUM_SAMPLE_NUM_OF_FRAMES), dtype=torch.float)
        self._rewind()

    def _rewind(self):
        '''Put the write cursor back to the state expected before the first frame.'''
        self.pos: int = 0                  # start of the next write
        self.end: int = self.frame_size    # end of the next write
        self.frame_slice_start: int = 0    # the first frame is stored in full
        self.at_end = False

    def add_frame(self, batch: torch.Tensor, _at_end = False):
        '''Append one frame, dropping its overlap with the stored audio. Returns `_at_end` unchanged.'''
        rows = self.audio.shape[0]
        frame = standardize_batch(batch)
        frame = frame.reshape(rows, frame.shape[-1])
        self.audio[:, self.pos:self.end] = frame[:, self.frame_slice_start:]

        # Every later frame contributes only its last `frame_shift` samples.
        self.pos, self.end = self.end, self.end + self.frame_shift
        self.frame_slice_start = self.frame_size - self.frame_shift

        #   Remove if _at_end ends up doing something
        return _at_end

    def add_presliced(self, batch: torch.Tensor):
        '''Append an arbitrary amount of audio exactly as given, without overlap handling.'''
        rows = self.audio.shape[0]
        chunk = standardize_batch(batch)
        width = chunk.shape[1]
        chunk = chunk.reshape(rows, width)
        self.audio[:, self.pos:self.pos + width] = chunk

        self.pos += width
        self.end = self.pos + self.frame_shift

    def get_current_audio(self) -> torch.Tensor:
        '''Return a detached copy of the audio written so far, passed through `output_transform`.'''
        written = self.end - self.frame_shift
        snapshot = self.audio[:, :written].detach().clone()
        return self.output_transform(snapshot)

    def reset(self):
        '''Replace the buffer with zeros of the same shape and rewind the cursor.'''
        self.audio = torch.zeros(self.audio.shape, dtype=torch.float)
        self.pos = 0
        self.end = self.frame_size
        self.frame_slice_start = 0
        self.at_end = False


def get_reference_batch(ds: Dataset, frame_size: int, output_transform=lambda x: x, seed=None) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick a random batch from ``ds`` and cut the same random window from both of its tensors.

    Args:
        ds: Indexable dataset whose items start with two tensors.
        frame_size: Window length along the last dimension.
        output_transform: Applied to each window before returning.
        seed: Seed for the random batch and window position.

    Returns:
        ``(output_transform(window_of_item0), output_transform(window_of_item1))``.

    Raises:
        ValueError: if every batch in ``ds`` is shorter than ``frame_size``
            (previously this looped forever).
    """
    rnd = random.Random(seed)
    too_short: set[int] = set()
    while True:
        idx = rnd.randint(0, len(ds) - 1)
        if idx in too_short:
            # Already known to be unusable; do not load it again.
            continue
        batches = ds[idx]
        batch, batch2 = batches[0], batches[1]
        print(batch.shape)
        if batch.shape[-1] < frame_size:
            too_short.add(idx)
            if len(too_short) == len(ds):
                raise ValueError(f"no batch in the dataset has at least {frame_size} samples")
            continue
        randpos = rnd.randint(0, batch.shape[-1] - frame_size)
        window = slice(randpos, randpos + frame_size)
        return output_transform(batch[:, window]), output_transform(batch2[:, window])

def get_all_frames(batch: torch.Tensor, frame_size: int, frame_shift: int, output_transform=lambda x: x) -> torch.Tensor:
    """Cut the last dimension of ``batch`` into windows of ``frame_size`` stepped by ``frame_shift``.

    The batch is first passed through ``standardize_batch``; the unfolded
    result is passed through ``output_transform`` before it is returned.
    """
    standardized = standardize_batch(batch)
    last_dim = standardized.ndim - 1
    windows = standardized.unfold(last_dim, frame_size, frame_shift)
    return output_transform(windows)

class MultiFrameLoader():
    '''Yields groups of up to `num_frames` consecutive overlapped frames per step. For TCNN.

    Each step returns `(mixed_frames, clean_frames, has_batch_ended)`; the
    last group of a batch may hold fewer than `num_frames` frames.
    '''
    def __init__(self,dl: DataLoader, frame_size: int, frame_shift: int, batch_size: int, num_frames: int, engine: Engine | None = None, output_transform = lambda x: x):
        self.dl = dl
        self.dl_iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.start_frame = 0
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        # Starting "at the end" makes the first __next__ call load a batch.
        self.at_end = True
        self.engine = engine
        self.output_transform = output_transform

    def __iter__(self):
        self.dl_iter = iter(self.dl)
        return self

    def __next__(self) -> tuple[torch.Tensor, torch.Tensor, bool]:
        if self.at_end:
            # Load the next batch and pre-cut both tensors into frames.
            mixed, clean = next(self.dl_iter)
            self.batch_mixed = get_all_frames(mixed, self.frame_size, self.frame_shift)
            self.batch_clean = get_all_frames(clean, self.frame_size, self.frame_shift)
            self.start_frame = 0
            self.at_end = False

        remaining = self.batch_mixed.shape[1] - self.start_frame
        if remaining <= self.num_frames:
            # This group uses up the batch.
            self.at_end = True
        take = min(self.num_frames, remaining)

        mixed_group = self.output_transform(self.batch_mixed.narrow(1, self.start_frame, take))
        clean_group = self.output_transform(self.batch_clean.narrow(1, self.start_frame, take))
        self.start_frame += take
        return mixed_group, clean_group, self.at_end

 

## Training Utils

In [7]:
# criterion = nn.MSELoss()

# Module-level performance counters; the training functions reset them on entry.
# pf_train_totals[0] accumulates forward-pass nanoseconds; index 1 is presumably
# the backward-pass total - TODO confirm against the training step.
pf_train_totals = [0, 0]
pf_train_num_loops = pf_eval_total = pf_eval_num_loops = 0

class PESQMetric(Metric):
    """Ignite metric: mean PESQ over every stitched sample seen since the last reset.

    ``update`` expects the engine output to be a sequence whose third element
    is a dict holding the stitched processed and clean audio under
    ``stitch_keys``; outputs without it are ignored.
    """
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        """
        Args:
            stitch_keys: ``(processed_key, clean_key)`` looked up in ``output[2]``.
            output_transform: Forwarded to ``Metric``.
            device: Forwarded to ``Metric``.
        """
        self.stitch_keys = stitch_keys
        self.running_total = 0.0
        self.num = 0
        super().__init__(output_transform, device)

    @reinit__is_reduced
    def reset(self):
        self.running_total = 0.0
        self.num = 0
        super().reset()

    @reinit__is_reduced
    def update(self, output):
        proc_key, clean_key = self.stitch_keys
        # Use the configured key here; the check used to be hard-coded to
        # "stitch_proc", so custom stitch_keys were silently ignored.
        if len(output) <= 2 or proc_key not in output[2]:
            return
        y_pred: np.ndarray = standardize_batch(output[2][proc_key]).cpu().numpy()
        y: np.ndarray = standardize_batch(output[2][clean_key]).cpu().numpy()
        for i in range(y.shape[0]):
            self.running_total += calc_pesq(y[i], y_pred[i])
            self.num += 1

    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        """Return the mean score; raises ``NotComputableError`` when nothing was accumulated."""
        if self.num == 0:
            raise NotComputableError("PESQ Metric must have one complete sample before computing")
        return self.running_total / self.num

class STOIMetric(Metric):
    """Ignite metric: mean STOI over every stitched sample seen since the last reset.

    ``update`` expects the engine output to be a sequence whose third element
    is a dict holding the stitched processed and clean audio under
    ``stitch_keys``; outputs without it are ignored.
    """
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        """
        Args:
            stitch_keys: ``(processed_key, clean_key)`` looked up in ``output[2]``.
            output_transform: Forwarded to ``Metric``.
            device: Forwarded to ``Metric``.
        """
        self.stitch_keys = stitch_keys
        self.running_total = 0.0
        self.num = 0
        super().__init__(output_transform, device)

    @reinit__is_reduced
    def reset(self):
        self.running_total = 0.0
        self.num = 0
        super().reset()

    @reinit__is_reduced
    def update(self, output):
        proc_key, clean_key = self.stitch_keys
        # Use the configured key here; the check used to be hard-coded to
        # "stitch_proc", so custom stitch_keys were silently ignored.
        if len(output) <= 2 or proc_key not in output[2]:
            return
        y_pred: np.ndarray = standardize_batch(output[2][proc_key]).cpu().numpy()
        y: np.ndarray = standardize_batch(output[2][clean_key]).cpu().numpy()
        for i in range(y.shape[0]):
            self.running_total += calc_stoi(y[i], y_pred[i])
            self.num += 1

    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        """Return the mean score; raises ``NotComputableError`` when nothing was accumulated."""
        if self.num == 0:
            raise NotComputableError("STOI Metric must have one complete sample before computing")
        return self.running_total / self.num


# Custom Ignite event fired on the training engine by run_eval() once the
# validator has finished its run.
class ValidationEvents(EventEnum):
    VALIDATION_COMPLETED = "validation_completed"

def register_custom_events(eng: Engine):
    """Register the notebook's custom event enums on ``eng`` so they can be fired and handled."""
    for event_group in (FrameLoaderEvents, ValidationEvents):
        eng.register_events(*event_group)

def log_trainer_loss(eng: Engine):
    """Print the current epoch, the iteration modulo ``iteration_ceiling``, and the last step output."""
    state = eng.state
    iterations = state.iteration % state.iteration_ceiling
    print(f"Epoch[{state.epoch}], Iter[{iterations}] Loss: {state.output}")

def log_custom(eng: Engine, **kwargs):
    """Print ``kwargs["template"]`` formatted with the engine state dict, the epoch, and ``kwargs``.

    Later sources win on key clashes: keyword arguments override ``epoch``,
    which overrides the state-dict entries.
    """
    fields = dict(eng.state_dict())
    fields["epoch"] = eng.state.epoch
    fields.update(kwargs)
    fmt_string: str = kwargs["template"]
    print(fmt_string.format(**fields))

def run_eval(eng: Engine, **kwargs):
    """Run the validator over the validation frames, then fire ``VALIDATION_COMPLETED`` on ``eng``.

    Keyword Args:
        validator: ``Engine`` that performs validation (required).
        val_frame_loader: ``FrameLoader`` it iterates over (required).

    Raises:
        TypeError: if either keyword argument is missing or ``None``.
    """
    validator: Engine = kwargs.get("validator", None)
    val_frame_loader: FrameLoader = kwargs.get("val_frame_loader", None)
    # The messages name this function; they previously pointed at log_eval_results.
    if validator is None:
        raise TypeError("run_eval must be passed the argument `validator` of type `Engine`")
    if val_frame_loader is None:
        raise TypeError("run_eval must be passed the argument `val_frame_loader` of type `FrameLoader`")

    validator.run(val_frame_loader)
    eng.fire_event(ValidationEvents.VALIDATION_COMPLETED)

def log_eval_results(eng: Engine, **kwargs):
    """Print the validator's PESQ, STOI and loss for the current epoch.

    Keyword Args:
        validator: ``Engine`` whose ``state.metrics`` are reported (required).
        prefix: Text printed before the line (default empty).
        metrics_out: Optional list that receives a copy of the metrics dict.

    Raises:
        TypeError: if ``validator`` is missing or ``None``.
    """
    validator: Engine = kwargs.get("validator")
    if validator is None:
        raise TypeError("log_eval_results must be passed the argument `validator` of type `Engine`")

    prefix = kwargs.get("prefix", "")
    metrics = validator.state.metrics
    metrics_out = kwargs.get("metrics_out")
    if metrics_out is not None:
        metrics_out.append(metrics.copy())
    print(f"{prefix}Epoch[{eng.state.epoch}] | PESQ:[{metrics['pesq']:.2f}] | STOI:[{metrics['stoi']:.2f}] | Loss:[{metrics['loss']}]")

def set_engine_custom_keys(eng: Engine):
    """Expose ``iteration_ceiling`` through the engine's state dict and initialise it to ``sys.maxsize``."""
    # This handler runs on every STARTED event; register the key only once so
    # repeated runs do not pile up duplicate entries.
    if "iteration_ceiling" not in eng.state_dict_user_keys:
        eng.state_dict_user_keys.append("iteration_ceiling")
    eng.state.iteration_ceiling = sys.maxsize

def set_iteration_ceiling(eng: Engine, *args):
    """Set ``eng.state.iteration_ceiling`` to the single extra argument, or to the current iteration count otherwise."""
    ceiling = args[0] if len(args) == 1 else eng.state.iteration
    eng.state.iteration_ceiling = ceiling

# Models

## SEGAN

In [None]:
# Default hyper-parameters for gan().
gan_hp = dict(
    frame_size=16384,    # samples per frame handed to the frame loaders
    frame_shift=8192,    # samples between consecutive frame starts
    g_lr=1.0e-4,         # generator RMSprop learning rate
    d_lr=1.0e-4,         # discriminator RMSprop learning rate
    batch_size=128,
    epochs=80,           # max_epochs of the training engine
    save=True,           # write final weights and out.json when the run ends
    load=None,           # or (generator_path, discriminator_path) state-dict files
    model_type="gan",
)

# Set to True to start training as soon as the gan() cell is executed.
GAN_RUN_ON_LOAD = False

In [None]:
def gan(hp: dict = gan_hp):
    """Train the SEGAN generator/discriminator pair and validate after every epoch.

    Builds the models, optimisers, frame loaders and Ignite engines from
    ``hp``, trains for ``hp["epochs"]`` epochs and, when ``hp["save"]`` is
    truthy, writes the final weights and an ``out.json`` summary into
    ``saved_models/gan_<timestamp>/``.

    Args:
        hp: Hyper-parameter dict. Reads ``frame_size``, ``frame_shift``,
            ``g_lr``, ``d_lr``, ``batch_size``, ``epochs``, ``save`` and
            ``load`` (``None`` or a ``(generator_path, discriminator_path)``
            pair of state-dict files).

    Returns:
        A dict with whatever had been recorded when the run ended: ``hp``,
        ``optimizer``, ``criterion``, ``metrics`` (one dict per validation
        run) and, after a completed run, ``total_time`` in nanoseconds.
        ``None`` if setup failed before the discriminator was created.

    Exceptions raised during the run are printed with their traceback and are
    not re-raised. Because the ``finally`` clause returns, anything else
    propagating out of the ``try`` is suppressed as well.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    pf_train_totals = [0,0]
    pf_train_num_loops = 0
    pf_eval_total = 0
    pf_eval_num_loops = 0
    try:
        datestring_at_start = datetime_string()
        # makedirs also creates "saved_models" when it does not exist yet;
        # plain mkdir aborted the whole run in that case.
        os.makedirs(f"saved_models/gan_{datestring_at_start}")

        from models.segan import Discriminator, Generator

        torch.cuda.empty_cache()
        out = {"hp": hp}
        gen = Generator().to(device=device)
        dcrim = Discriminator().to(device=device)

        if hp["load"] is not None:
            gen.load_state_dict(torch.load(hp["load"][0], weights_only=True))
            dcrim.load_state_dict(torch.load(hp["load"][1], weights_only=True))

        g_optimizer = torch.optim.RMSprop(gen.parameters(), lr=hp["g_lr"])
        d_optimizer = torch.optim.RMSprop(dcrim.parameters(), lr=hp["d_lr"])
        criterion = nn.L1Loss()
        # Store only the class names so the summary stays JSON-serialisable.
        out["optimizer"] = str(g_optimizer).split("(")[0]
        out["criterion"] = str(criterion).split("(")[0]


        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"),
                                    get_sequential_wav_paths("data/speech_ordered/train"),
                                    batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        # Each dataset item is already a full batch, so the DataLoaders keep
        # their default batch size.
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)
        # Fixed reference pair passed to the discriminator on every call.
        r = get_reference_batch(train_dataset, hp["frame_size"], lambda x: x.view(hp["batch_size"],1,-1))
        ref_batch = torch.cat((r[0],r[1]),dim=1).to(device=device)
        # Second generator input; refilled in place with normal noise before use.
        z = torch.zeros((hp["batch_size"],1024,8)).to(device=device)
        print(ref_batch.shape)

        def train_step(engine, batch):
            """One discriminator update followed by one generator update; returns the generator loss."""
            # --- discriminator update ---
            dcrim.train()
            dcrim.zero_grad()
            x, y = batch[0].to(device=device), batch[1].to(device=device)
            nn.init.normal_(z)
            # NOTE(review): the real pair is stacked as (mixed, clean) while the
            # generated pair below is (generated, mixed); confirm the channel
            # order the discriminator expects.
            combined_batch = torch.cat((x.clone().detach(),y.clone().detach()),dim=1)
            output = dcrim(combined_batch, ref_batch)
            clean_loss = torch.mean((output - 1.0) ** 2)
            clean_loss.backward()

            gen_out = gen(x, z)
            # Detach so this backward pass stops at the discriminator. The
            # generator gradients it used to produce were discarded by
            # gen.zero_grad() below anyway.
            output = dcrim(torch.cat((gen_out.detach(), x), dim=1),ref_batch)
            noisy_loss = torch.mean(output ** 2)
            noisy_loss.backward()

            d_optimizer.step()

            # --- generator update ---
            gen.train()
            gen.zero_grad()
            gen_out = gen(x, z)
            gen_noise_pair = torch.cat((gen_out, x), dim=1)
            output = dcrim(gen_noise_pair, ref_batch)

            # Adversarial term plus a weighted mean absolute error to the clean target.
            g_loss_ = 0.5 * torch.mean((output - 1.0) ** 2)
            l1_dist = torch.abs(torch.add(gen_out, torch.neg(y)))
            g_cond_loss = 100 * torch.mean(l1_dist)
            g_loss = g_loss_ + g_cond_loss
            g_loss.backward()
            g_optimizer.step()

            return g_loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"],
                                        hp["batch_size"], engine=trainer,
                                        output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        # Buffers that stitch the per-frame outputs back into whole signals for PESQ/STOI.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                                    output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                                     output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        def val_step(engine, batch):
            """Run the generator on one frame; attach the stitched audio when the batch is complete."""
            # NOTE(review): the generator is never switched to eval mode here;
            # confirm whether that is intended for this model.
            with torch.no_grad():
                x, y = batch[0].to(device=device), batch[1].to(device=device)
                nn.init.normal_(z)
                y_pred = gen(x, z)
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }
        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        # Keep the two checkpoints with the highest validation PESQ.
        checkpoint_to_save = {"gen": gen, "dcrim": dcrim}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, f"saved_models/gan_{datestring_at_start}",
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"],
                                     hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        # out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)                        ###
        # out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)                        ###
        # out["eval"]=pf_eval_total / float(pf_eval_num_loops)                          ###


    except Exception as e:
        print("failed")
        print(traceback.format_exc())
        print(e)

    finally:
        # "dcrim" only exists once both models were built, which implies that
        # `out`, `gen` and the run directory exist too.
        if "dcrim" in locals():
            if hp["save"]:
                torch.save(gen.state_dict(),f"saved_models/gan_{datestring_at_start}/gen_final.pt")
                torch.save(dcrim.state_dict(),f"saved_models/gan_{datestring_at_start}/dcrim_final.pt")
                with open(f"saved_models/gan_{datestring_at_start}/out.json","w") as f:
                    json.dump({k: out[k] for k in out.keys() - {'gen', 'dcrim'}},f)
            return out
        else:
            return None

# Opt-in training run: only starts when GAN_RUN_ON_LOAD is truthy.
if GAN_RUN_ON_LOAD:
    gan()


## WaveCRN

In [None]:
# Default hyper-parameters for crn().
crn_hp = dict(
    frame_size=96,       # also passed to the ConvBSRU constructor
    frame_shift=40,
    lr=2.0e-5,           # Adam learning rate
    batch_size=128,
    epochs=80,
    save=True,
    load=None,           # or the path of a model state-dict file
    model_type="crn",
)

# Set to True to start training as soon as the crn() cell is executed.
CRN_RUN_ON_LOAD = False

In [None]:
def crn(hp: dict = crn_hp):
    """Train the WaveCRN (``ConvBSRU``) model and return a summary of the run.

    Creates ``saved_models/crn_<timestamp>/``, trains with Adam and L1 loss for
    ``hp["epochs"]`` epochs, triggers ``run_eval`` with the validation loader
    after every epoch (loss, PESQ and STOI metrics are attached to the
    validator) and keeps the two checkpoints with the highest PESQ in the run
    directory.

    Args:
        hp: Hyperparameter dict. Keys read here: ``frame_size``, ``frame_shift``,
            ``lr``, ``batch_size``, ``epochs``, ``save`` and ``load`` (``None``
            or the path of a state dict to restore before training).

    Returns:
        ``None`` if the run failed before the model was built. Otherwise the
        ``out`` dict with ``hp``, ``model``, ``optimizer``, ``criterion``,
        ``metrics``, ``total_time`` (ns) and ``fwd`` / ``bck`` / ``eval`` (mean
        ns per timed iteration); later keys are missing when the run stopped
        early. When ``hp["save"]`` is true, ``final.pt`` (model state dict) and
        ``out.json`` (``out`` without the model) are written to the run directory.

    Exceptions raised during setup or training are printed, not propagated.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    # Module-level timing accumulators, reset at the start of every run.
    pf_train_totals = [0,0]     # [forward ns, loss + backward + optimizer-step ns]
    pf_train_num_loops = 0      # number of timed training iterations
    pf_eval_total = 0           # validation forward ns
    pf_eval_num_loops = 0       # number of timed validation iterations
    model = None                # stays None until the network exists (checked in ``finally``)
    try:
        datestring_at_start = datetime_string()
        run_dir = f"saved_models/crn_{datestring_at_start}"
        # makedirs also creates a missing ``saved_models/`` parent, which os.mkdir would not.
        os.makedirs(run_dir, exist_ok=True)

        from models.wavecrn import ConvBSRU

        torch.cuda.empty_cache()
        out = {"hp": hp}
        model = ConvBSRU(frame_size=hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
        if hp["load"] is not None:
            model.load_state_dict(torch.load(hp["load"], weights_only=True))
        out["model"] = model

        optimizer = torch.optim.Adam(model.parameters(),lr=hp["lr"])
        criterion = nn.L1Loss()
        # Keep only the text before the first "(" of each repr so ``out`` stays JSON-serialisable.
        out["optimizer"] = str(optimizer).split("(")[0]
        out["criterion"] = str(criterion).split("(")[0]

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"),
                                      get_sequential_wav_paths("data/speech_ordered/train"),
                                      batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)

        def train_step(engine, batch):
            """Run one optimisation step and return the scalar loss; also accumulates timings."""
            global pf_train_totals, pf_train_num_loops
            # NOTE(review): no torch.cuda.synchronize() surrounds these timers, so on GPU the
            # figures may exclude kernel time that is still in flight - confirm this is acceptable.
            pf_train_forward = time.perf_counter_ns()
            model.train()
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)
            pf_train_back = time.perf_counter_ns()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)
            pf_train_num_loops += 1
            return loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"], batch_size=hp["batch_size"], engine=trainer, output_transform=lambda x: x.view((hp["batch_size"],1,-1)))

        # One reconstructor per signal: model output and clean reference.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        def val_step(engine, batch):
            """Forward one validation batch.

            Returns ``(y_pred, y)``; when ``batch[2]`` is truthy a third element
            with the stitched audio of both reconstructors is appended.
            """
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()
            model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                pf_eval_num_loops += 1
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }
        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        # Keep the two checkpoints with the highest validation PESQ.
        checkpoint_to_save = {"model":model}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, run_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"],
                                     hp["batch_size"],output_transform=lambda x: x.view((hp["batch_size"],1,-1)))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        # Mean nanoseconds per timed iteration.
        out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)
        out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)
        out["eval"]=pf_eval_total / float(pf_eval_num_loops)

    except Exception as e:
        # format_exc() returns the traceback text; print_exc() returns None, which the
        # previous print(traceback.print_exc()) echoed as a stray "None".
        print(traceback.format_exc())
        print(e)
    finally:
        # NOTE(review): returning from ``finally`` discards any in-flight exception, so an
        # interrupted run (KeyboardInterrupt) still saves and returns ``out`` - presumably
        # intentional; confirm before restructuring.
        if model is None:
            return None
        if hp["save"]:
            torch.save(model.state_dict(),f"{run_dir}/final.pt")
            json_dict = {k: out[k] for k in out.keys() - {'model'}}
            with open(f"{run_dir}/out.json","w") as f:
                json.dump(json_dict,f)
        return out

# Start WaveCRN training as soon as this cell is executed, but only when CRN_RUN_ON_LOAD is set.
if CRN_RUN_ON_LOAD:
    crn()

In [None]:
# if 'crn_model' in locals(): torch.save(crn_model.state_dict(),f"saved_models/crn_{datetime_string()}.pt")

## RHR-Net

In [None]:
# Default hyperparameters for rnn().
rnn_hp = {
    "frame_size": 1024,     # forwarded to FrameLoader and FrameReconstructor
    "frame_shift": 256,     # forwarded to FrameLoader and FrameReconstructor
    "lr": 1.0e-4,           # RMSprop learning rate
    "batch_size": 128,      # forwarded to SortedBatchDataset, FrameLoader and FrameReconstructor
    "epochs": 30,           # max_epochs of the training engine
    "save": True,           # write final.pt and out.json when rnn() finishes
    "load": None,           # optional path of a state dict to restore before training
    "model_type": "rnn",    # stored in out.json through out["hp"]; not otherwise read by rnn()
}

# Set to True to start training as soon as the cell defining rnn() is executed.
RNN_RUN_ON_LOAD = False

In [None]:
def rnn(hp: dict = rnn_hp):
    """Train the RHR-Net model and return a summary of the run.

    Creates ``saved_models/rnn_<timestamp>/``, builds ``RHRNet`` from
    ``models/rhrnetdir/rhrnet_hyperparameters.yaml``, trains with RMSprop and
    L1 loss for ``hp["epochs"]`` epochs, triggers ``run_eval`` with the
    validation loader after every epoch (loss, PESQ and STOI metrics are
    attached to the validator) and keeps the two checkpoints with the highest
    PESQ in the run directory.

    Args:
        hp: Hyperparameter dict. Keys read here: ``frame_size``, ``frame_shift``,
            ``lr``, ``batch_size``, ``epochs``, ``save`` and ``load`` (``None``
            or the path of a state dict to restore before training).

    Returns:
        ``None`` if the run failed before the model was built. Otherwise the
        ``out`` dict with ``hp``, ``datetime_str``, ``model``, ``optimizer``,
        ``criterion``, ``metrics``, ``total_time`` (ns) and ``fwd`` / ``bck`` /
        ``eval`` (mean ns per timed iteration); later keys are missing when the
        run stopped early. When ``hp["save"]`` is true, ``final.pt`` (model
        state dict) and ``out.json`` (``out`` without the model) are written to
        the run directory.

    Exceptions raised during setup or training are printed, not propagated.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    # Module-level timing accumulators, reset at the start of every run.
    pf_train_totals = [0,0]     # [forward ns, loss + backward + optimizer-step ns]
    pf_train_num_loops = 0      # number of timed training iterations
    pf_eval_total = 0           # validation forward ns
    pf_eval_num_loops = 0       # number of timed validation iterations
    model = None                # stays None until the network exists (checked in ``finally``)
    try:
        datestring_at_start = datetime_string()
        run_dir = f"saved_models/rnn_{datestring_at_start}"
        # makedirs also creates a missing ``saved_models/`` parent, which os.mkdir would not.
        os.makedirs(run_dir, exist_ok=True)

        import yaml
        from models.rhrnetdir.Arg_Parser import Recursive_Parse
        from models.rhrnet import RHRNet
        # Named model_hp so it does not shadow the module-level ``rnn_hp`` training dict.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; fine for this
        # project-local file, but do not point it at untrusted YAML.
        with open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8') as hp_file:
            model_hp = Recursive_Parse(yaml.load(hp_file, Loader=yaml.Loader))
        torch.cuda.empty_cache()

        out = {"hp": hp, "datetime_str": datestring_at_start}
        model = RHRNet(model_hp).to(device=device)
        if hp["load"] is not None:
            model.load_state_dict(torch.load(hp["load"], weights_only=True))
        out["model"] = model

        optimizer = torch.optim.RMSprop(model.parameters(),lr=hp["lr"])
        criterion = nn.L1Loss()
        # Keep only the text before the first "(" of each repr so ``out`` stays JSON-serialisable.
        out["optimizer"] = str(optimizer).split("(")[0]
        out["criterion"] = str(criterion).split("(")[0]

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)
        print(f"val dataset:{len(val_dataset)}")

        def train_step(engine, batch):
            """Run one optimisation step and return the scalar loss; also accumulates timings."""
            global pf_train_totals, pf_train_num_loops
            # NOTE(review): no torch.cuda.synchronize() surrounds these timers, so on GPU the
            # figures may exclude kernel time that is still in flight - confirm this is acceptable.
            pf_train_forward = time.perf_counter_ns()
            model.train()
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)
            pf_train_back = time.perf_counter_ns()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)
            pf_train_num_loops += 1
            return loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = FrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"], batch_size=hp["batch_size"], engine=trainer, output_transform=lambda x: x.view((hp["batch_size"],-1)))

        # One reconstructor per signal: model output and clean reference.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],-1)))
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"], output_transform=lambda x: x.view((hp["batch_size"],-1)))
        def val_step(engine, batch):
            """Forward one validation batch.

            Returns ``(y_pred, y)``; when ``batch[2]`` is truthy a third element
            with the stitched audio of both reconstructors is appended.
            """
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()
            model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                pf_eval_num_loops += 1
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['loss'])
        val_metrics: dict[str, Metric] = {
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }

        for name, metric in val_metrics.items():
            metric.attach(validator, name)

        # Keep the two checkpoints with the highest validation PESQ.
        checkpoint_to_save = {"model":model}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, run_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )

        metrics_out = []
        out["metrics"] = metrics_out
        # NOTE(review): the training loader reshapes to (batch_size, -1) but this validation
        # loader is built without an output_transform - confirm the loader default matches.
        val_dataloader = FrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        # Mean nanoseconds per timed iteration.
        out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)
        out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)
        out["eval"]=pf_eval_total / float(pf_eval_num_loops)

    except Exception as e:
        # Print the traceback as well; the message alone does not say where the run failed.
        print(traceback.format_exc())
        print(e)
    finally:
        # NOTE(review): returning from ``finally`` discards any in-flight exception, so an
        # interrupted run (KeyboardInterrupt) still saves and returns ``out`` - presumably
        # intentional; confirm before restructuring.
        if model is None:
            return None
        if hp["save"]:
            torch.save(model.state_dict(),f"{run_dir}/final.pt")
            with open(f"{run_dir}/out.json","w") as f:
                json.dump({k: out[k] for k in out.keys() - {'model'}},f)
        return out

# Start RHR-Net training as soon as this cell is executed, but only when RNN_RUN_ON_LOAD is set.
if RNN_RUN_ON_LOAD:
    rnn()

In [None]:
# a=ns_to_sec(pf_train_totals[0] / float(pf_train_num_loops))                         ###
# b=ns_to_sec(pf_train_totals[1] / float(pf_train_num_loops))                         ###
# c=ns_to_sec(pf_eval_total / float(pf_train_num_loops/4))                            ###
# print(f"train forward:{a:.20f} | backprop:{b:.20f} | eval forward:{c:.20f}")

In [None]:
# if 'rnn_model' in locals(): torch.save(rnn_model.state_dict(),f"saved_models/rnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

## TCNN

In [None]:
# Default hyperparameters for cnn().
cnn_hp = {
    "frame_size": 320,      # forwarded to MultiFrameLoader and FrameReconstructor
    "frame_shift": 160,     # forwarded to MultiFrameLoader and FrameReconstructor
    "num_frames": 300,      # forwarded to MultiFrameLoader
    "lr": 1.0e-3,           # Adam learning rate
    "batch_size": 16,       # forwarded to SortedBatchDataset, MultiFrameLoader and FrameReconstructor
    "epochs": 30,           # max_epochs of the training engine
    "save": False,          # final.pt / out.json are not written (best-PESQ checkpoints still are)
    "load": None,           # optional path of a state dict to restore before training
    "model_type": "cnn",    # not read by cnn() itself
}

# Set to True to start training as soon as the cell defining cnn() is executed.
CNN_RUN_ON_LOAD = False

In [None]:
def cnn(hp: dict = cnn_hp):
    """Train the TCNN model and return a summary of the run.

    Creates ``saved_models/cnn_<timestamp>/``, trains with Adam and L1 loss for
    ``hp["epochs"]`` epochs, triggers ``run_eval`` with the validation loader
    after every epoch (loss, PESQ and STOI metrics are attached to the
    validator) and keeps the two checkpoints with the highest PESQ in the run
    directory. DEBUG-level logging is disabled for the duration of the call.

    Args:
        hp: Hyperparameter dict. Keys read here: ``frame_size``, ``frame_shift``,
            ``num_frames``, ``lr``, ``batch_size``, ``epochs``, ``save`` and
            ``load`` (``None`` or the path of a state dict to restore before
            training).

    Returns:
        ``None`` if the run failed before the model was built. Otherwise the
        ``out`` dict with ``hp``, ``model``, ``optimizer``, ``criterion``,
        ``metrics``, ``total_time`` (ns) and ``fwd`` / ``bck`` / ``eval`` (mean
        ns per timed iteration); later keys are missing when the run stopped
        early. When ``hp["save"]`` is true, ``final.pt`` (model state dict) and
        ``out.json`` (``out`` without the model) are written to the run directory.

    Exceptions raised during setup or training are printed, not propagated.
    """
    global pf_train_totals, pf_train_num_loops, pf_eval_total, pf_eval_num_loops
    # Module-level timing accumulators, reset at the start of every run.
    pf_train_totals = [0,0]     # [forward ns, loss + backward + optimizer-step ns]
    pf_train_num_loops = 0      # number of timed training iterations
    pf_eval_total = 0           # validation forward ns
    pf_eval_num_loops = 0       # number of timed validation iterations
    model = None                # stays None until the network exists (checked in ``finally``)
    # Silence DEBUG records while training; re-enabled in ``finally``.
    logging.disable(logging.DEBUG)
    try:
        datestring_at_start = datetime_string()
        run_dir = f"saved_models/cnn_{datestring_at_start}"
        # makedirs also creates a missing ``saved_models/`` parent, which os.mkdir would not.
        os.makedirs(run_dir, exist_ok=True)

        from models.tcnn import TCNN

        torch.cuda.empty_cache()
        out = {"hp": hp}
        model = TCNN().to(device=device)
        if hp["load"] is not None:
            model.load_state_dict(torch.load(hp["load"], weights_only=True))
        out["model"] = model

        optimizer = torch.optim.Adam(model.parameters(),lr=hp["lr"])
        criterion = nn.L1Loss()
        # Keep only the text before the first "(" of each repr so ``out`` stays JSON-serialisable.
        out["optimizer"] = str(optimizer).split("(")[0]
        out["criterion"] = str(criterion).split("(")[0]

        _dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"),
                                      get_sequential_wav_paths("data/speech_ordered/train"),
                                      batch_size=hp["batch_size"])
        train_dataset, val_dataset = _dataset.split(0.2)
        del _dataset
        base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
        base_val_dataloader = DataLoader(val_dataset)

        def train_step(engine, batch):
            """Run one optimisation step and return the scalar loss; also accumulates timings."""
            global pf_train_totals, pf_train_num_loops
            # NOTE(review): no torch.cuda.synchronize() surrounds these timers, so on GPU the
            # figures may exclude kernel time that is still in flight - confirm this is acceptable.
            pf_train_forward = time.perf_counter_ns()
            model.train()
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)
            pf_train_back = time.perf_counter_ns()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)
            pf_train_num_loops += 1
            return loss.item()

        trainer = Engine(train_step)
        register_custom_events(trainer)
        RunningAverage(output_transform=lambda x: x).attach(trainer,'loss')
        pbar = ProgressBar(desc="Training Epoch")
        pbar.attach(trainer,['loss'])

        trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)

        train_dataloader = MultiFrameLoader(base_train_dataloader, hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                              hp["num_frames"], engine=trainer, output_transform=lambda x: x.unsqueeze(1))
        # One reconstructor per signal: model output and clean reference.
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        def val_step(engine, batch):
            """Forward one validation batch and feed every frame to the reconstructors.

            Returns ``(y_pred, y)``; when ``batch[2]`` is truthy a third element
            with the stitched audio of both reconstructors is appended.
            """
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()
            model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                pf_eval_num_loops += 1
                # The model output carries several frames along dim 2; add them one by one.
                for i in range(y_pred.shape[2]):
                    proc_frame_constructor.add_frame(y_pred[:,:,i,:])
                    clean_frame_constructor.add_frame(y[:,:,i,:])
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        validator = Engine(val_step)
        val_metrics: dict[str, Metric] = {
            "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
            "pesq": PESQMetric(),
            "stoi": STOIMetric()
        }
        for name, metric in val_metrics.items():
            metric.attach(validator, name)
        # Per-iteration loss for the validation progress bar.
        RunningAverage(output_transform=lambda x: criterion(x[0],x[1]).item()).attach(validator,'running_loss')
        pbar = ProgressBar(desc="Validation")
        pbar.attach(validator,['running_loss'])

        # Keep the two checkpoints with the highest validation PESQ.
        checkpoint_to_save = {"model":model}
        checkpoint_handler = Checkpoint(
            checkpoint_to_save, run_dir,
            filename_prefix="best", score_function=lambda eng: eng.state.metrics['pesq'],n_saved=2
        )
        metrics_out = []
        out["metrics"] = metrics_out
        # NOTE(review): this validation loader is handed engine=trainer, unlike the validation
        # loaders built in crn() and rnn() - confirm that coupling to the trainer is intended.
        val_dataloader = MultiFrameLoader(base_val_dataloader, hp["frame_size"], hp["frame_shift"], hp["batch_size"],
                                              hp["num_frames"], engine=trainer, output_transform=lambda x: x.unsqueeze(1))
        trainer.add_event_handler(Events.EPOCH_COMPLETED,run_eval,validator=validator,val_frame_loader=val_dataloader)
        trainer.add_event_handler(ValidationEvents.VALIDATION_COMPLETED,log_eval_results,validator=validator,metrics_out=metrics_out)
        validator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        time_train = time.perf_counter_ns()
        trainer.run(train_dataloader, max_epochs=hp["epochs"])
        out["total_time"] = time.perf_counter_ns() - time_train
        # Mean nanoseconds per timed iteration.
        out["fwd"]=pf_train_totals[0] / float(pf_train_num_loops)
        out["bck"]=pf_train_totals[1] / float(pf_train_num_loops)
        out["eval"]=pf_eval_total / float(pf_eval_num_loops)

    except Exception as e:
        # format_exc() returns the traceback text; print_exc() returns None, which the
        # previous print(traceback.print_exc()) echoed as a stray "None".
        print(traceback.format_exc())
        print(e)
    finally:
        logging.disable(logging.NOTSET)
        # NOTE(review): returning from ``finally`` discards any in-flight exception, so an
        # interrupted run (KeyboardInterrupt) still saves and returns ``out`` - presumably
        # intentional; confirm before restructuring.
        if model is None:
            return None
        if hp["save"]:
            torch.save(model.state_dict(),f"{run_dir}/final.pt")
            json_dict = {k: out[k] for k in out.keys() - {'model'}}
            with open(f"{run_dir}/out.json","w") as f:
                json.dump(json_dict,f)
        return out

# Start TCNN training as soon as this cell is executed, but only when CNN_RUN_ON_LOAD is set.
if CNN_RUN_ON_LOAD:
    cnn()

# Model Evaluation

## Model Testing

In [None]:
from models.segan import Generator
from models.tcnn import TCNN
from models.rhrnet import RHRNet
from models.wavecrn import ConvBSRU
from pesq.cypesq import NoUtterancesError
from utils.rhr_hp_load import load_rnn_hp

In [None]:
def gan_test(model: Generator, dl: FrameLoader, hp: dict = gan_hp):
    """Score a GAN generator on every clip served by ``dl``.

    Each batch is denoised with an all-zero latent ``z``; the output and the
    clean reference are accumulated in two ``FrameReconstructor`` objects.
    Whenever ``batch[2]`` is truthy the stitched audio is read back and PESQ,
    STOI and segmental SNR are computed row by row.

    Args:
        model: Generator to evaluate; switched to eval mode.
        dl: Loader yielding ``(noisy, clean, flag)`` batches.
        hp: Reads ``frame_size``, ``frame_shift`` and ``batch_size``.

    Returns:
        List of ``{"pesq", "stoi", "ssnr"}`` dicts, one per successfully scored
        row. Rows whose metrics raise are reported and skipped. Any other
        failure (or an interrupt) prints a traceback and returns the results
        gathered so far.
    """
    results = []
    try:
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        z = torch.zeros((hp["batch_size"],1024,8)).to(device=device)
        model.eval()
        with torch.no_grad():
            for batch in tqdm(dl):
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x,z)
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio().numpy(force=True)
                    y_stitch = clean_frame_constructor.get_current_audio().numpy(force=True)
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    for i in range(y_stitch.shape[0]):
                        try:
                            sq = calc_pesq(y_stitch[i], y_pred_stitch[i])
                            si = calc_stoi(y_stitch[i], y_pred_stitch[i])
                            ssnr = calc_snrseg(y_stitch[i], y_pred_stitch[i])
                            results.append({"pesq":sq, "stoi":si, "ssnr":ssnr})
                        except NoUtterancesError:
                            print("NoUtterancesError")
                        except Exception as e:
                            # Still best-effort (skip this row), but say what went wrong.
                            print(f"exception: {type(e).__name__}: {e}")
    except BaseException:
        # BaseException on purpose: an interrupted run still returns the partial results.
        traceback.print_exc()
    return results

def crn_test(model: ConvBSRU, dl: FrameLoader, hp: dict = crn_hp):
    """Score a WaveCRN model on every clip served by ``dl``.

    Model output and clean reference are accumulated in two
    ``FrameReconstructor`` objects. Whenever ``batch[2]`` is truthy the
    stitched audio is read back and PESQ, STOI and segmental SNR are computed
    row by row.

    Args:
        model: Network to evaluate; switched to eval mode.
        dl: Loader yielding ``(noisy, clean, flag)`` batches.
        hp: Reads ``frame_size``, ``frame_shift`` and ``batch_size``.

    Returns:
        List of ``{"pesq", "stoi", "ssnr"}`` dicts, one per successfully scored
        row. Rows that raise ``NoUtterancesError`` are reported and skipped,
        as in ``rnn_test`` / ``cnn_test``. Any other failure (or an interrupt)
        prints a traceback once and returns the results gathered so far.
    """
    results = []
    try:
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], hp["batch_size"])
        model.eval()
        with torch.no_grad():
            for batch in tqdm(dl):
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = model(x)
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio().numpy(force=True)
                    y_stitch = clean_frame_constructor.get_current_audio().numpy(force=True)
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    for i in range(y_stitch.shape[0]):
                        try:
                            sq = calc_pesq(y_stitch[i], y_pred_stitch[i])
                            si = calc_stoi(y_stitch[i], y_pred_stitch[i])
                            ssnr = calc_snrseg(y_stitch[i], y_pred_stitch[i])
                            results.append({"pesq":sq, "stoi":si, "ssnr":ssnr})
                        except NoUtterancesError:
                            # Skip only this row. The previous handler re-raised a blank
                            # Exception(), which hid the cause and aborted the whole run.
                            print("NoUtterancesError")
    except BaseException:
        # BaseException on purpose: an interrupted run still returns the partial results.
        traceback.print_exc()
    return results
    
def rnn_test(model: RHRNet, dl: FrameLoader, hp: dict = rnn_hp):
    """Score an RHR-Net model on every clip served by ``dl``.

    Model output and clean reference are accumulated in two
    ``FrameReconstructor`` objects. Whenever ``batch[2]`` is truthy the
    stitched audio is read back and PESQ, STOI and segmental SNR are computed
    row by row.

    Args:
        model: Network to evaluate; switched to eval mode.
        dl: Loader yielding ``(noisy, clean, flag)`` batches.
        hp: Reads ``frame_size``, ``frame_shift`` and ``batch_size``.

    Returns:
        List of ``{"pesq", "stoi", "ssnr"}`` dicts, one per scored row. Rows
        raising ``NoUtterancesError`` are skipped. Any other failure prints a
        traceback and returns what was collected so far.
    """
    results = []
    try:
        frame_size, frame_shift, batch_size = hp["frame_size"], hp["frame_shift"], hp["batch_size"]
        denoised_builder = FrameReconstructor(frame_size, frame_shift, batch_size)
        reference_builder = FrameReconstructor(frame_size, frame_shift, batch_size)
        model.eval()
        with torch.no_grad():
            for batch in tqdm(dl):
                noisy = batch[0].to(device)
                target = batch[1].to(device)
                denoised_builder.add_frame(model(noisy))
                reference_builder.add_frame(target)
                if not batch[2]:
                    # Clips not complete yet; keep accumulating frames.
                    continue
                denoised_np = denoised_builder.get_current_audio().numpy(force=True)
                reference_np = reference_builder.get_current_audio().numpy(force=True)
                denoised_builder.reset()
                reference_builder.reset()
                for row in range(reference_np.shape[0]):
                    reference_clip, denoised_clip = reference_np[row], denoised_np[row]
                    try:
                        scores = {
                            "pesq": calc_pesq(reference_clip, denoised_clip),
                            "stoi": calc_stoi(reference_clip, denoised_clip),
                            "ssnr": calc_snrseg(reference_clip, denoised_clip),
                        }
                    except NoUtterancesError:
                        print("NoUtterancesError")
                    else:
                        results.append(scores)
    except:
        traceback.print_exc()
    finally:
        return results

def cnn_test(model: TCNN, dl: MultiFrameLoader, hp: dict = cnn_hp):
    """Score a TCNN model on every clip served by ``dl``.

    The model output is split along dim 2 and each slice, together with the
    matching clean slice, is added to a ``FrameReconstructor``. Whenever
    ``batch[2]`` is truthy the stitched audio is read back and PESQ, STOI and
    segmental SNR are computed row by row.

    Args:
        model: Network to evaluate; switched to eval mode.
        dl: Loader yielding ``(noisy, clean, flag)`` batches.
        hp: Reads ``frame_size``, ``frame_shift`` and ``batch_size``.

    Returns:
        List of ``{"pesq", "stoi", "ssnr"}`` dicts, one per scored row. Rows
        raising ``NoUtterancesError`` are skipped. Any other failure prints a
        traceback and returns what was collected so far.
    """
    results = []
    try:
        frame_size, frame_shift, batch_size = hp["frame_size"], hp["frame_shift"], hp["batch_size"]
        denoised_builder = FrameReconstructor(frame_size, frame_shift, batch_size)
        reference_builder = FrameReconstructor(frame_size, frame_shift, batch_size)
        model.eval()
        with torch.no_grad():
            for batch in tqdm(dl):
                noisy = batch[0].to(device)
                target = batch[1].to(device)
                denoised = model(noisy)
                for frame_idx in range(denoised.shape[2]):
                    denoised_builder.add_frame(denoised[:, :, frame_idx, :])
                    reference_builder.add_frame(target[:, :, frame_idx, :])
                if not batch[2]:
                    # Clips not complete yet; keep accumulating frames.
                    continue
                denoised_np = denoised_builder.get_current_audio().numpy(force=True)
                reference_np = reference_builder.get_current_audio().numpy(force=True)
                denoised_builder.reset()
                reference_builder.reset()
                for row in range(reference_np.shape[0]):
                    reference_clip, denoised_clip = reference_np[row], denoised_np[row]
                    try:
                        scores = {
                            "pesq": calc_pesq(reference_clip, denoised_clip),
                            "stoi": calc_stoi(reference_clip, denoised_clip),
                            "ssnr": calc_snrseg(reference_clip, denoised_clip),
                        }
                    except NoUtterancesError:
                        print("NoUtterancesError")
                    else:
                        results.append(scores)
    except:
        traceback.print_exc()
    finally:
        return results

# Flip to True to run the held-out test for all four trained models and dump the
# per-clip scores to results/<snr_version>_<timestamp>.json.
run_base_test = False
if run_base_test:
    snr_version = "high_snr"
    # Machine-specific absolute location of the trained checkpoints, kept in one place
    # so it only has to be edited once when the notebook moves.
    checkpoint_root = "/vol/research/FYP_Leo/speech_denoising_fyp/saved_models/final"
    mixed_test_dir = f"data/mixed/test/{snr_version}"
    clean_test_dir = f"data/speech_ordered/test/{snr_version}"

    #gan
    print("Testing GAN...")
    model = Generator().to(device=device)
    model.load_state_dict(torch.load(f"{checkpoint_root}/gan_29-04-2025--11-58-59/best_checkpoint_1.8107.pt")["gen"])
    ds = SortedBatchDataset(get_sequential_wav_paths(mixed_test_dir),get_sequential_wav_paths(clean_test_dir),gan_hp["batch_size"])
    dl = FrameLoader(ds, gan_hp["frame_size"], gan_hp["frame_shift"], gan_hp["batch_size"],
                    output_transform=lambda x: x.view((gan_hp["batch_size"],1,-1)))
    gan_results = gan_test(model, dl, gan_hp)
    # Release the model before the next one is loaded.
    del model, ds, dl
    torch.cuda.empty_cache()

    #crn
    print("Testing CRN...")
    model = ConvBSRU(frame_size=crn_hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
    model.load_state_dict(torch.load(f"{checkpoint_root}/crn_20-04-2025--12-49-08/best_model_1.7157.pt"))
    ds = SortedBatchDataset(get_sequential_wav_paths(mixed_test_dir),get_sequential_wav_paths(clean_test_dir),crn_hp["batch_size"])
    dl = FrameLoader(ds, crn_hp["frame_size"], crn_hp["frame_shift"], crn_hp["batch_size"],
                    output_transform=lambda x: x.view((crn_hp["batch_size"],1,-1)))
    crn_results = crn_test(model, dl, crn_hp)
    del model, ds, dl
    torch.cuda.empty_cache()

    #rnn
    print("Testing RNN...")
    import yaml
    from models.rhrnetdir.Arg_Parser import Recursive_Parse
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; fine for this
    # project-local file, but do not point it at untrusted YAML.
    with open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8') as hp_file:
        _rnn_hp = Recursive_Parse(yaml.load(hp_file, Loader=yaml.Loader))
    model = RHRNet(_rnn_hp).to(device=device)
    model.load_state_dict(torch.load(f"{checkpoint_root}/rnn_21-04-2025--01-00-43/best_model_2.0291.pt"))
    ds = SortedBatchDataset(get_sequential_wav_paths(mixed_test_dir),get_sequential_wav_paths(clean_test_dir),rnn_hp["batch_size"])
    dl = FrameLoader(ds, rnn_hp["frame_size"], rnn_hp["frame_shift"], rnn_hp["batch_size"])
    rnn_results = rnn_test(model, dl, rnn_hp)
    del model, ds, dl
    torch.cuda.empty_cache()

    #cnn
    print("Testing CNN...")
    model = TCNN().to(device=device)
    model.load_state_dict(torch.load(f"{checkpoint_root}/cnn_08-05-2025--12-02-34/best_model_1.9495.pt"))
    ds = SortedBatchDataset(get_sequential_wav_paths(mixed_test_dir),get_sequential_wav_paths(clean_test_dir),cnn_hp["batch_size"])
    dl = MultiFrameLoader(ds,cnn_hp["frame_size"], cnn_hp["frame_shift"],cnn_hp["batch_size"],cnn_hp["num_frames"],output_transform=lambda x: x.unsqueeze(1))
    cnn_results = cnn_test(model, dl, cnn_hp)
    del model, ds, dl
    torch.cuda.empty_cache()

    # One combined file with the per-clip scores of every model.
    all_res = {"gan":gan_results, "crn":crn_results, "rnn":rnn_results, "cnn":cnn_results}
    with open(f"results/{snr_version}_{datetime_string()}.json","w") as file:
        json.dump(all_res, file)


## Performance Speed Testing

In [None]:
def gan_one_sec_test(model: Generator, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the GAN generator on one randomly positioned excerpt per item of `dl`.

    For every item a window of `hp["frame_size"]` samples is cut at a random offset from the
    noisy (`sample[0]`) and clean (`sample[1]`) signals, enhanced in a single forward pass
    and scored against the clean excerpt.

    Args:
        model: generator, called as `model(x, z)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` and `"batch_size"`.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"` and `"time"`. `"time"` is in
        nanoseconds and covers slicing, host-to-device copy, latent sampling and the forward pass.
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        sample: torch.Tensor
        rnd = random.Random()
        z = torch.zeros((hp["batch_size"],1024,8)).to(device=device)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            max_start = sample[0].shape[-1]-hp["frame_size"]
            if max_start < 0:
                # Shorter than one excerpt: randint would raise and abort the whole run
                continue
            perf_time = time.perf_counter_ns()
            start = rnd.randint(0, max_start)
            stop = start + hp["frame_size"]
            noisy = sample[0].squeeze()[start:stop].clone().detach().view((1,1,-1))
            clean = sample[1].squeeze()[start:stop].clone().detach().view((1,1,-1))
            with torch.no_grad():
                x, y = noisy.to(device=device), clean.to(device=device)
                nn.init.normal_(z)
                y_pred: torch.Tensor = model(x, z)
            if use_cuda:
                # CUDA kernels are launched asynchronously; wait for them before reading the clock
                torch.cuda.synchronize()
            perf_time = time.perf_counter_ns() - perf_time
            y_pred_np = y_pred.squeeze().numpy(force=True)
            y_np = y.squeeze().numpy(force=True)
            try:
                sq = calc_pesq(y_np, y_pred_np)
                si = calc_stoi(y_np, y_pred_np)
                ssnr = calc_snrseg(y_np, y_pred_np)
            except Exception:
                # Best effort: an excerpt that cannot be scored is skipped (interrupts still propagate)
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr": ssnr, "time": perf_time
            }
            results.append(res)
    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results
    
def crn_one_sec_test(model: ConvBSRU, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the CRN on one random `SAMPLE_RATE`-sample (one second) excerpt per item of `dl`.

    `dl` should be a DataLoader that provides the full audio file, not a FrameLoader.

    The excerpt is zero-padded, split into frames, enhanced frame by frame, reassembled and
    scored against the reassembled clean excerpt.

    Args:
        model: network, called as `model(x)` like the other CRN tests in this file.
        dl: loader yielding `(noisy, clean, ...)`; each signal is indexed as `[:, :, t]`.
        hp: reads `"frame_size"` and `"frame_shift"`.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"` and `"time"`. `"time"` is in
        nanoseconds and covers the whole per-frame loop (device copy, forward pass, reconstruction).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        sample: torch.Tensor
        rnd = random.Random()
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            max_start = sample[0].shape[-1]-SAMPLE_RATE
            if max_start < 0:
                # Shorter than one second: randint would raise and abort the whole run
                continue
            start = rnd.randint(0, max_start)
            stop = start + SAMPLE_RATE
            _noisy = sample[0][:,:,start:stop].clone().detach()
            _clean = sample[1][:,:,start:stop].clone().detach()
            _,_,pad = calc_windowing(_noisy.shape[-1],hp["frame_size"], hp["frame_shift"])
            _noisy = torch.nn.functional.pad(_noisy,(0,pad), value=0.0)
            _clean = torch.nn.functional.pad(_clean,(0,pad), value=0.0)
            noisy = get_all_frames(_noisy,hp["frame_size"], hp["frame_shift"])
            clean = get_all_frames(_clean,hp["frame_size"], hp["frame_shift"])
            perf_time = time.perf_counter_ns()
            with torch.no_grad():
                for j in range(noisy.shape[1]):
                    # Frame num dimension becomes 1, can be reused as channel index
                    x = noisy.narrow(1, j, 1).to(device=device)
                    y = clean.narrow(1, j, 1).to(device=device)
                    y_pred: torch.Tensor = model(x)
                    proc_frame_constructor.add_frame(y_pred)
                    clean_frame_constructor.add_frame(y)
            if use_cuda:
                # CUDA kernels are launched asynchronously; wait for them before reading the clock
                torch.cuda.synchronize()
            perf_time = time.perf_counter_ns() - perf_time
            y_pred = proc_frame_constructor.get_current_audio().squeeze()
            y = clean_frame_constructor.get_current_audio().squeeze()
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()
            y_pred_np = y_pred.numpy(force=True)
            y_np = y.numpy(force=True)
            try:
                sq = calc_pesq(y_np, y_pred_np)
                si = calc_stoi(y_np, y_pred_np)
                ssnr = calc_snrseg(y_np, y_pred_np)
            except Exception:
                # Best effort: report the failure and skip this excerpt (interrupts still propagate)
                traceback.print_exc()
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time
            }
            results.append(res)
    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def rnn_one_sec_test(model: RHRNet, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the RNN on one random `SAMPLE_RATE`-sample (one second) excerpt per item of `dl`.

    The excerpt is zero-padded, split into frames, enhanced frame by frame, reassembled and
    scored against the reassembled clean excerpt.

    Args:
        model: network, called as `model(x)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` and `"frame_shift"`.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"` and `"time"`. `"time"` is in
        nanoseconds and covers the whole per-frame loop (device copy, forward pass, reconstruction).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        sample: torch.Tensor
        rnd = random.Random()
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            max_start = sample[0].shape[-1]-SAMPLE_RATE
            if max_start < 0:
                # Shorter than one second: randint would raise and abort the whole run
                continue
            start = rnd.randint(0, max_start)
            stop = start + SAMPLE_RATE
            _noisy = sample[0].squeeze()[start:stop].clone().detach()
            _clean = sample[1].squeeze()[start:stop].clone().detach()
            _,_,pad = calc_windowing(_noisy.shape[-1],hp["frame_size"], hp["frame_shift"])
            noisy_pad = torch.nn.functional.pad(_noisy,(0,pad), value=0.0)
            clean_pad = torch.nn.functional.pad(_clean,(0,pad), value=0.0)
            noisy = get_all_frames(noisy_pad,hp["frame_size"], hp["frame_shift"])
            clean = get_all_frames(clean_pad,hp["frame_size"], hp["frame_shift"])
            perf_time = time.perf_counter_ns()
            with torch.no_grad():
                for j in range(noisy.shape[1]):
                    x = noisy.narrow(1, j, 1).squeeze(1).to(device=device)
                    y = clean.narrow(1, j, 1).squeeze(1).to(device=device)
                    y_pred: torch.Tensor = model(x)
                    proc_frame_constructor.add_frame(y_pred)
                    clean_frame_constructor.add_frame(y)
            if use_cuda:
                # CUDA kernels are launched asynchronously; wait for them before reading the clock
                torch.cuda.synchronize()
            perf_time = time.perf_counter_ns() - perf_time
            y_pred = proc_frame_constructor.get_current_audio().squeeze()
            y = clean_frame_constructor.get_current_audio().squeeze()
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()
            y_pred_np = y_pred.numpy(force=True)
            y_np = y.numpy(force=True)
            try:
                sq = calc_pesq(y_np, y_pred_np)
                si = calc_stoi(y_np, y_pred_np)
                ssnr = calc_snrseg(y_np, y_pred_np)
            except NoUtterancesError:
                # Excerpt cannot be scored by PESQ: skip it
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time
            }
            results.append(res)
    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def cnn_one_sec_test(model: TCNN, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the CNN on one random `SAMPLE_RATE`-sample (one second) excerpt per item of `dl`.

    The excerpt is zero-padded, split into frames and fed to the model in chunks of at most
    `hp["num_frames"]` frames; the output frames are reassembled and scored against the
    reassembled clean excerpt.

    Args:
        model: network, called as `model(x)`.
        dl: loader yielding `(noisy, clean, ...)`; each signal is indexed as `[:, :, t]`.
        hp: reads `"frame_size"`, `"frame_shift"` and `"num_frames"`.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"` and `"time"`. `"time"` is in
        nanoseconds and covers the whole chunk loop (device copy, forward pass, reconstruction).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        sample: torch.Tensor
        rnd = random.Random()
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            max_start = sample[0].shape[-1]-SAMPLE_RATE
            if max_start < 0:
                # Shorter than one second: randint would raise and abort the whole run
                continue
            start = rnd.randint(0, max_start)
            stop = start + SAMPLE_RATE
            _noisy = sample[0][:,:,start:stop].clone().detach()
            _clean = sample[1][:,:,start:stop].clone().detach()
            _,_,pad = calc_windowing(_noisy.shape[-1],hp["frame_size"], hp["frame_shift"])
            _noisy = torch.nn.functional.pad(_noisy,(0,pad), value=0.0)
            _clean = torch.nn.functional.pad(_clean,(0,pad), value=0.0)
            noisy = get_all_frames(_noisy,hp["frame_size"], hp["frame_shift"])
            clean = get_all_frames(_clean,hp["frame_size"], hp["frame_shift"])
            j = 0
            j_end = noisy.shape[1]
            perf_time = time.perf_counter_ns()
            with torch.no_grad():
                while j < j_end:
                    # Recomputed per chunk so the final, shorter chunk never runs past the last frame
                    chunk_len = min(hp["num_frames"], j_end - j)
                    x = noisy.narrow(1,j,chunk_len).unsqueeze(1).to(device=device)
                    y = clean.narrow(1,j,chunk_len).unsqueeze(1).to(device=device)
                    j += chunk_len
                    y_pred: torch.Tensor = model(x)
                    # `k`, not `i`: the outer item counter must not be overwritten here
                    for k in range(y_pred.shape[2]):
                        proc_frame_constructor.add_frame(y_pred[:,:,k,:])
                        clean_frame_constructor.add_frame(y[:,:,k,:])
            if use_cuda:
                # CUDA kernels are launched asynchronously; wait for them before reading the clock
                torch.cuda.synchronize()
            perf_time = time.perf_counter_ns() - perf_time
            y_pred = proc_frame_constructor.get_current_audio().squeeze()
            y = clean_frame_constructor.get_current_audio().squeeze()
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()
            y_pred_np = y_pred.numpy(force=True)
            y_np = y.numpy(force=True)
            try:
                sq = calc_pesq(y_np, y_pred_np)
                si = calc_stoi(y_np, y_pred_np)
                ssnr = calc_snrseg(y_np, y_pred_np)
            except Exception:
                # Best effort: an excerpt that cannot be scored is skipped (interrupts still propagate)
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time
            }
            results.append(res)
    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

# One-second benchmark driver: runs every model on the chosen test split and writes all
# results to a single timestamped JSON file. Flip the flag to True to execute it.
run_onesec_test = False
if run_onesec_test:
    snr_version = "low_snr"
    ds = SortedBatchDataset(get_sequential_wav_paths(f"data/mixed/test/{snr_version}"), get_sequential_wav_paths(f"data/speech_ordered/test/{snr_version}"),1)
    logging.disable(logging.DEBUG)

    gan_results = []
    print("Testing GAN...")
    dl = DataLoader(ds)
    model = Generator().to(device=device)
    # map_location lets checkpoints saved on a GPU load on whatever `device` is in use
    model.load_state_dict(torch.load(r"saved_models/final/gan_29-04-2025--11-58-59/best_checkpoint_1.8107.pt", map_location=device)["gen"])
    gan_test_hp = {"frame_size":16384, "frame_shift": 16384 - (16384//4), "batch_size":1}
    gan_results = gan_one_sec_test(model, dl, gan_test_hp)
    del model, dl
    torch.cuda.empty_cache()

    crn_results = []
    print("Testing CRN...")
    dl = DataLoader(ds)
    # Defined before the model so the network is built from this cell's own settings
    # rather than from a `crn_hp` left in the kernel by another cell
    crn_test_hp = {"frame_size":96, "frame_shift": 96 - (96//4), "batch_size":1}
    model = ConvBSRU(frame_size=crn_test_hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/crn_20-04-2025--12-49-08/best_model_1.7157.pt", map_location=device))
    crn_results = crn_one_sec_test(model, dl, crn_test_hp)
    del model, dl
    torch.cuda.empty_cache()

    rnn_results = []
    print("Testing RNN...")
    dl = DataLoader(ds)
    _rnn_hp = load_rnn_hp()
    model = RHRNet(_rnn_hp).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/rnn_21-04-2025--01-00-43/best_model_2.0291.pt", map_location=device))
    rnn_test_hp = {"frame_size":320, "frame_shift": 320 - (320//4), "batch_size":1}
    rnn_results = rnn_one_sec_test(model, dl, rnn_test_hp)
    del model, dl
    torch.cuda.empty_cache()

    cnn_results = []
    print("Testing CNN...")
    dl = DataLoader(ds)
    model = TCNN().to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/cnn_08-05-2025--12-02-34/best_model_1.9495.pt", map_location=device))
    cnn_test_hp = {"frame_size":320, "frame_shift": 320 - (320//4),"num_frames": 300, "batch_size":1}
    cnn_results = cnn_one_sec_test(model, dl, cnn_test_hp)
    print(cnn_results)
    del model, dl
    torch.cuda.empty_cache()


    all_res = {"gan":gan_results, "crn":crn_results, "rnn":rnn_results, "cnn":cnn_results}
    # Make sure the output directory exists so a long run is not lost at the final write
    os.makedirs("results", exist_ok=True)
    with open(f"results/{snr_version}_one-sec_{datetime_string()}.json","w") as file:
        json.dump(all_res, file)

### First test: how fast each model can process 1 second of audio
### Second test: inference with a frame shift of at most 10 ms


In [None]:
def gan_10ms_test(model: Generator, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the GAN generator when it advances through each signal in 160-sample hops.

    A window of `hp["frame_size"]` samples ending at the current hop is zero-padded where it
    falls outside the signal and passed through the model; only the last 160 output samples of
    each window are kept and concatenated into the enhanced signal.

    Args:
        model: generator, called as `model(x, z)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` only; the hop is fixed at 160 samples.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        in nanoseconds), `"audio_sample_len"` and `"frames"` (number of forward passes).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    frame_shift = 160
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        frame_size = hp["frame_size"]
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        # NOTE(review): the latent stays all-zero here, unlike the one-second test which resamples it
        z = torch.zeros((1,1024,8)).to(device=device)
        use_cuda = torch.cuda.is_available()
        i = 0
        sample: list[torch.Tensor]
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            noisy_sig, clean_sig = sample[0].squeeze(), sample[1].squeeze()
            sig_len = sample[0].shape[-1]
            start = frame_shift - frame_size
            end = frame_shift
            x: torch.Tensor
            y: torch.Tensor
            at_end = False
            perf_time = 0
            j=0
            while not at_end:
                j+=1
                # Clamp the window to the signal and zero-pad whatever falls outside it,
                # on either side (a signal shorter than one frame needs both)
                lo, hi = max(start, 0), min(end, sig_len)
                pad = (lo - start, end - hi)
                x = torch.nn.functional.pad(noisy_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                y = torch.nn.functional.pad(clean_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                at_end = end > sig_len

                with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                    if use_cuda:
                        torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                    _perf_time = time.perf_counter_ns()
                    y_pred: torch.Tensor = model(x, z)
                    if use_cuda:
                        torch.cuda.synchronize()    # CUDA launches are asynchronous
                    perf_time += (time.perf_counter_ns() - _perf_time)
                proc_frame_constructor.add_presliced(y_pred[:,:,frame_size-frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[:,:,frame_size-frame_shift:frame_size].clone().detach())
                start += frame_shift
                end += frame_shift

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time, "audio_sample_len": sample[0].shape[-1], "frames":j
            }
            results.append(res)

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def crn_10ms_test(model: ConvBSRU, dl: FrameLoader, hp: dict, n_samples = None):
    '''Time the CRN frame by frame on the frames served by `dl`.

    Every batch from `dl` is one forward pass. When `batch[2]` is truthy the frames gathered
    so far are reassembled, scored, and the per-utterance counters start again.

    Args:
        model: network, called as `model(x)`.
        dl: frame loader yielding `(noisy, clean, flag, ...)`; `flag` presumably marks the last
            frame of an utterance - TODO confirm against FrameLoader.
        hp: reads `"frame_size"` and `"frame_shift"` (used for reconstruction).
        n_samples: number of scored utterances to collect; `None` means no limit.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        for that utterance, in nanoseconds), `"audio_sample_len"` and `"frames"`.
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    results = []
    try:
        if n_samples is None:
            n_samples = float("inf")
        model.eval()
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], hp["frame_shift"], 1)
        use_cuda = torch.cuda.is_available()
        i = 1
        perf_time = 0
        j=0
        for batch in tqdm(dl):
            j+=1
            x, y = batch[0].to(device), batch[1].to(device)
            with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                if use_cuda:
                    torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                _perf_time = time.perf_counter_ns()
                y_pred = model(x)
                if use_cuda:
                    torch.cuda.synchronize()    # CUDA launches are asynchronous
                perf_time += time.perf_counter_ns() - _perf_time
            proc_frame_constructor.add_frame(y_pred.clone().detach())
            clean_frame_constructor.add_frame(y.clone().detach())
            if not batch[2]:
                continue

            y_pred_np = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y_np = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()
            # Snapshot and restart the per-utterance counters before scoring, so neither a
            # scored nor a skipped utterance leaks its time or frame count into the next one
            utt_frames, utt_time = j, perf_time
            j = 0
            perf_time = 0

            try:
                sq = calc_pesq(y_np, y_pred_np)
                si = calc_stoi(y_np, y_pred_np)
                ssnr = calc_snrseg(y_np, y_pred_np)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": utt_time, "audio_sample_len": y_pred_np.shape[-1], "frames":utt_frames
            }
            results.append(res)

            if i >= n_samples:
                break
            i+=1

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def rnn_10ms_test(model: RHRNet, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the RNN when it advances through each signal in 160-sample hops.

    A window of `hp["frame_size"]` samples ending at the current hop is zero-padded where it
    falls outside the signal and passed through the model; only the last 160 output samples of
    each window are kept and concatenated into the enhanced signal.

    Args:
        model: network, called as `model(x)` with `x` viewed as `(1, frame_size)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` only; the hop is fixed at 160 samples.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        in nanoseconds), `"audio_sample_len"` (length of the reassembled output) and `"frames"`
        (number of forward passes, reported like the sibling 10 ms tests).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    frame_shift = 160
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        frame_size = hp["frame_size"]
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            noisy_sig, clean_sig = sample[0].squeeze(), sample[1].squeeze()
            sig_len = sample[0].shape[-1]
            start = frame_shift - frame_size
            end = frame_shift
            x: torch.Tensor
            y: torch.Tensor
            at_end = False
            perf_time = 0
            j=0
            while not at_end:
                j+=1
                # Clamp the window to the signal and zero-pad whatever falls outside it,
                # on either side (a signal shorter than one frame needs both)
                lo, hi = max(start, 0), min(end, sig_len)
                pad = (lo - start, end - hi)
                x = torch.nn.functional.pad(noisy_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,-1).to(device=device)
                y = torch.nn.functional.pad(clean_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,-1).to(device=device)
                at_end = end > sig_len

                with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                    if use_cuda:
                        torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                    _perf_time = time.perf_counter_ns()
                    y_pred: torch.Tensor = model(x)
                    if use_cuda:
                        torch.cuda.synchronize()    # CUDA launches are asynchronous
                    perf_time += (time.perf_counter_ns() - _perf_time)
                proc_frame_constructor.add_presliced(y_pred[:,frame_size-frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[:,frame_size-frame_shift:frame_size].clone().detach())
                start += frame_shift
                end += frame_shift

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time, "audio_sample_len": y_pred.shape[-1], "frames":j
            }
            results.append(res)

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def cnn_10ms_test(model: TCNN, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the CNN when it advances through each signal one 160-sample hop at a time.

    The signal is zero-padded and framed up front. At every step the model sees a growing,
    then sliding, context of at most `hp["num_frames"]` frames ending at the current frame,
    and only the last 160 samples of the last output frame are kept.

    Args:
        model: network, called as `model(x)` with `x` viewed as `(1, 1, frames, frame_size)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` and `"num_frames"`; the hop is fixed at 160 samples.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        in nanoseconds), `"audio_sample_len"` (length of the reassembled output) and `"frames"`.
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    frame_shift = 160
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        sample: torch.Tensor
        proc_frame_constructor = FrameReconstructor(hp["frame_size"], frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(hp["frame_size"], frame_shift, 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            start = frame_shift - hp["frame_size"]
            x: torch.Tensor
            y: torch.Tensor
            at_end = False
            perf_time = 0
            # Left-pad so the first frame ends exactly one hop into the signal
            x_pad = torch.nn.functional.pad(sample[0].squeeze(),(abs(start),0),value=0.0)
            y_pad = torch.nn.functional.pad(sample[1].squeeze(),(abs(start),0),value=0.0)
            _,_,pad = calc_windowing(x_pad.shape[0],hp["frame_size"],frame_shift)
            x_pad = torch.nn.functional.pad(x_pad,(0,pad),value=0.0)
            y_pad = torch.nn.functional.pad(y_pad,(0,pad),value=0.0)
            x_frames = get_all_frames(x_pad,hp["frame_size"],frame_shift)
            y_frames = get_all_frames(y_pad,hp["frame_size"],frame_shift)

            first_window=0
            windows_count=0
            j=0
            y_pred: torch.Tensor
            while not at_end:
                j+=1
                # first_window + windows_count is the index of the frame this step ends on
                if first_window+windows_count >= x_frames.shape[1]-1:
                    at_end = True
                windows_count+=1
                # Start sliding only once the context would exceed num_frames; sliding as soon
                # as it equals num_frames skips one frame and can run past the last frame
                if windows_count > hp["num_frames"]:
                    windows_count = hp["num_frames"]
                    first_window += 1
                x = x_frames.narrow(1,first_window, windows_count)
                x = x.view(1,1,x.shape[1], hp["frame_size"]).to(device=device)
                y = y_frames.narrow(1,first_window, windows_count)
                y = y.view(1,1,y.shape[1], hp["frame_size"]).to(device=device)

                with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                    if use_cuda:
                        torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                    _perf_time = time.perf_counter_ns()
                    y_pred = model(x)
                    if use_cuda:
                        torch.cuda.synchronize()    # CUDA launches are asynchronous
                    # Only the forward pass is timed, as in the other 10 ms tests
                    perf_time += (time.perf_counter_ns() - _perf_time)
                proc_frame_constructor.add_presliced(y_pred[0,0,-1,hp["frame_size"]-frame_shift:hp["frame_size"]].clone().detach())
                clean_frame_constructor.add_presliced(y[0,0,-1,hp["frame_size"]-frame_shift:hp["frame_size"]].clone().detach())

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time, "audio_sample_len": y_pred.shape[-1], "frames":j
            }
            results.append(res)

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

# 10 ms benchmark driver: runs every model on the chosen test split (at most 1000 items each)
# and writes all results to a single timestamped JSON file.
run_10ms_tests = True
if run_10ms_tests:
    snr_version = "high_snr"
    ds = SortedBatchDataset(get_sequential_wav_paths(f"data/mixed/test/{snr_version}"), get_sequential_wav_paths(f"data/speech_ordered/test/{snr_version}"),1)
    logging.disable(logging.DEBUG)

    gan_results = []
    print("Testing GAN...")
    gan_test_hp = {"frame_size":16384, "frame_shift": 160, "batch_size":1}
    dl = DataLoader(ds)
    model = Generator().to(device=device)
    # map_location lets checkpoints saved on a GPU load on whatever `device` is in use
    model.load_state_dict(torch.load(r"saved_models/final/gan_29-04-2025--11-58-59/best_checkpoint_1.8107.pt", map_location=device)["gen"])
    gan_results = gan_10ms_test(model, dl, gan_test_hp,1000)
    del model, dl
    torch.cuda.empty_cache()

    crn_results = []
    print("Testing CRN...")
    # The CRN frame is only 96 samples, so it advances by 48 samples rather than 160
    crn_test_hp = {"frame_size":96, "frame_shift": 48, "batch_size":1}
    dl = DataLoader(ds)
    fl = FrameLoader(dl, crn_test_hp["frame_size"],crn_test_hp["frame_shift"],1, output_transform=lambda x: x.view(1,1,-1))
    model = ConvBSRU(frame_size=crn_test_hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/crn_20-04-2025--12-49-08/best_model_1.7157.pt", map_location=device))
    crn_results = crn_10ms_test(model, fl, crn_test_hp,1000)
    # `fl` wraps `dl`, so it has to be dropped as well
    del model, dl, fl
    torch.cuda.empty_cache()

    rnn_results = []
    print("Testing RNN...")
    rnn_test_hp = {"frame_size":320, "frame_shift": 160, "batch_size":1}
    dl = DataLoader(ds)
    _rnn_hp = load_rnn_hp()
    model = RHRNet(_rnn_hp).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/rnn_21-04-2025--01-00-43/best_model_2.0291.pt", map_location=device))
    rnn_results = rnn_10ms_test(model, dl, rnn_test_hp,1000)
    del model, dl
    torch.cuda.empty_cache()

    cnn_results = []
    print("Testing CNN...")
    cnn_test_hp = {"frame_size":320, "frame_shift": 160,"num_frames": 200, "batch_size":1}
    dl = DataLoader(ds)
    model = TCNN().to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/cnn_08-05-2025--12-02-34/best_model_1.9495.pt", map_location=device))
    cnn_results = cnn_10ms_test(model, dl, cnn_test_hp,1000)
    print(cnn_results)
    del model, dl
    torch.cuda.empty_cache()


    all_res = {"gan":gan_results, "crn":crn_results, "rnn":rnn_results, "cnn":cnn_results}
    # Make sure the output directory exists so a long run is not lost at the final write
    os.makedirs("results", exist_ok=True)
    with open(f"results/{snr_version}_10ms_{datetime_string()}.json","w") as file:
        json.dump(all_res, file)

In [None]:
# Manual re-save of the latest results under a fresh timestamp. Relies on `snr_version`
# and the four `*_results` lists still being defined in the kernel.
all_res = dict(gan=gan_results, crn=crn_results, rnn=rnn_results, cnn=cnn_results)
with open(f"results/{snr_version}_10ms_{datetime_string()}.json", mode="w") as file:
    json.dump(all_res, fp=file)

In [None]:
def gan_3ms_test(model: Generator, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the GAN generator when it advances through each signal in 48-sample hops.

    A window of `hp["frame_size"]` samples ending at the current hop is zero-padded where it
    falls outside the signal and passed through the model; only the last 48 output samples of
    each window are kept and concatenated into the enhanced signal.

    Args:
        model: generator, called as `model(x, z)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` only; the hop is fixed at 48 samples.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        in nanoseconds), `"audio_sample_len"` and `"frames"` (number of forward passes).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    frame_shift = 48
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        frame_size = hp["frame_size"]
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        # NOTE(review): the latent stays all-zero here, unlike the one-second test which resamples it
        z = torch.zeros((1,1024,8)).to(device=device)
        use_cuda = torch.cuda.is_available()
        i = 0
        sample: list[torch.Tensor]
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            noisy_sig, clean_sig = sample[0].squeeze(), sample[1].squeeze()
            sig_len = sample[0].shape[-1]
            start = frame_shift - frame_size
            end = frame_shift
            x: torch.Tensor
            y: torch.Tensor
            at_end = False
            perf_time = 0
            j=0
            while not at_end:
                j+=1
                # Clamp the window to the signal and zero-pad whatever falls outside it,
                # on either side (a signal shorter than one frame needs both)
                lo, hi = max(start, 0), min(end, sig_len)
                pad = (lo - start, end - hi)
                x = torch.nn.functional.pad(noisy_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                y = torch.nn.functional.pad(clean_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                at_end = end > sig_len

                with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                    if use_cuda:
                        torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                    _perf_time = time.perf_counter_ns()
                    y_pred: torch.Tensor = model(x, z)
                    if use_cuda:
                        torch.cuda.synchronize()    # CUDA launches are asynchronous
                    perf_time += (time.perf_counter_ns() - _perf_time)
                proc_frame_constructor.add_presliced(y_pred[:,:,frame_size-frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[:,:,frame_size-frame_shift:frame_size].clone().detach())
                start += frame_shift
                end += frame_shift

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time, "audio_sample_len": sample[0].shape[-1], "frames":j
            }
            results.append(res)

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def crn_3ms_test(model: ConvBSRU, dl: DataLoader, hp: dict, n_samples = None):
    '''Time the CRN when it advances through each signal in 48-sample hops.

    A window of `hp["frame_size"]` samples ending at the current hop is zero-padded where it
    falls outside the signal and passed through the model; only the last 48 output samples of
    each window are kept and concatenated into the enhanced signal.

    Args:
        model: network, called as `model(x)` with `x` viewed as `(1, 1, frame_size)`.
        dl: loader yielding `(noisy, clean, ...)`; presumably one full utterance per item - TODO confirm.
        hp: reads `"frame_size"` only; the hop is fixed at 48 samples.
        n_samples: maximum number of items to draw; `None` means `len(dl)`.

    Returns:
        List of dicts with keys `"pesq"`, `"stoi"`, `"ssnr"`, `"time"` (summed forward-pass time
        in nanoseconds), `"audio_sample_len"` and `"frames"` (number of forward passes).
        Whatever was collected is returned even if the run is aborted by an error or interrupt.
    '''
    frame_shift = 48
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        frame_size = hp["frame_size"]
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        use_cuda = torch.cuda.is_available()
        i = 0
        sample: list[torch.Tensor]
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i+=1
            noisy_sig, clean_sig = sample[0].squeeze(), sample[1].squeeze()
            sig_len = sample[0].shape[-1]
            start = frame_shift - frame_size
            end = frame_shift
            x: torch.Tensor
            y: torch.Tensor
            at_end = False
            perf_time = 0
            j=0
            while not at_end:
                j+=1
                # Clamp the window to the signal and zero-pad whatever falls outside it,
                # on either side (a signal shorter than one frame needs both)
                lo, hi = max(start, 0), min(end, sig_len)
                pad = (lo - start, end - hi)
                x = torch.nn.functional.pad(noisy_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                y = torch.nn.functional.pad(clean_sig[lo:hi].clone().detach(),pad,value=0.0).view(1,1,-1).to(device=device)
                at_end = end > sig_len

                with torch.no_grad():   # inference only: no autograd bookkeeping inside the timed call
                    if use_cuda:
                        torch.cuda.synchronize()    # drain earlier GPU work so it is not billed to this call
                    _perf_time = time.perf_counter_ns()
                    y_pred: torch.Tensor = model(x)
                    if use_cuda:
                        torch.cuda.synchronize()    # CUDA launches are asynchronous
                    perf_time += (time.perf_counter_ns() - _perf_time)
                proc_frame_constructor.add_presliced(y_pred[:,:,frame_size-frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[:,:,frame_size-frame_shift:frame_size].clone().detach())
                start += frame_shift
                end += frame_shift

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            res = {
                "pesq": sq, "stoi": si, "ssnr":ssnr, "time": perf_time, "audio_sample_len": sample[0].shape[-1], "frames":j
            }
            results.append(res)

    except BaseException:
        # Deliberately broad (includes KeyboardInterrupt) so partial results survive an aborted run
        traceback.print_exc()
    return results

def rnn_3ms_test(model: RHRNet, dl: DataLoader, hp: dict, n_samples = None):
    """Run `model` frame-by-frame over `dl` with a 48-sample hop and score the output.

    A window of ``hp["frame_size"]`` samples is slid over ``sample[0]`` in steps
    of 48 samples (3 ms at the notebook's 16 kHz SAMPLE_RATE). The window is
    zero-padded on the left while it still starts before the signal and on the
    right once it runs past the end; the first right-padded window is the last
    one processed. From every model output, and from the matching window of
    ``sample[1]``, only the final 48 samples are passed to a FrameReconstructor.
    The two reconstructed signals are then scored with PESQ, STOI and segmental SNR.

    Args:
        model: network under test, called as ``model(x)`` with ``x`` viewed as ``(1, frame_size)``.
        dl: iterable of samples; ``sample[0]`` is the model input, ``sample[1]`` the metric reference.
        hp: hyper-parameter dict; only ``"frame_size"`` is read here.
        n_samples: maximum number of samples to evaluate; ``None`` means ``len(dl)``.

    Returns:
        A list with one dict per scored sample holding ``"pesq"``, ``"stoi"``,
        ``"ssnr"``, ``"time"`` (accumulated ``perf_counter_ns`` deltas around the
        model calls), ``"audio_sample_len"`` (length of the reconstructed output)
        and ``"frames"`` (number of windows processed). Samples for which a
        metric raises ``NoUtterancesError`` are skipped. If anything else fails,
        the traceback is printed and the results gathered so far are returned.
    """
    frame_shift = 48
    results = []

    def _cut_frame(signal: torch.Tensor, lo: int, hi: int, padding) -> torch.Tensor:
        # Slice [lo, hi) out of the 1-D signal, zero-pad if requested and move it to `device`.
        chunk = signal[lo:hi].clone().detach()
        if padding is not None:
            chunk = torch.nn.functional.pad(chunk, padding, value=0.0)
        return chunk.view(1, -1).to(device=device)

    try:
        if n_samples is None:
            n_samples = len(dl)
        frame_size = hp["frame_size"]
        model.eval()
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i += 1

            noisy = sample[0].squeeze()
            clean = sample[1].squeeze()
            total = sample[0].shape[-1]   # the input length bounds the slicing of both signals

            # The first window ends after one hop, so it starts before the signal.
            start = frame_shift - frame_size
            end = frame_shift
            at_end = False
            perf_time = 0
            j = 0
            while not at_end:
                j += 1
                if start < 0:
                    lo, hi, padding = 0, end, (abs(start), 0)
                elif end > total:
                    lo, hi, padding = start, total, (0, end - total)
                    at_end = True
                else:
                    lo, hi, padding = start, end, None
                x = _cut_frame(noisy, lo, hi, padding)
                y = _cut_frame(clean, lo, hi, padding)

                # NOTE(review): only the call itself is timed; if `device` is CUDA the kernels
                # may still be running when the timer stops. Consider torch.cuda.synchronize()
                # (and torch.no_grad()) applied uniformly to every *_3ms_test function.
                _perf_time = time.perf_counter_ns()
                y_pred: torch.Tensor = model(x)
                perf_time += (time.perf_counter_ns() - _perf_time)

                # Only the newest hop of each window is kept.
                proc_frame_constructor.add_presliced(y_pred[:, frame_size - frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[:, frame_size - frame_shift:frame_size].clone().detach())
                start += frame_shift
                end += frame_shift

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            results.append({
                "pesq": sq, "stoi": si, "ssnr": ssnr, "time": perf_time,
                "audio_sample_len": y_pred.shape[-1],
                # Recorded like the sibling tests so per-frame time can be derived from this entry.
                "frames": j,
            })
    except BaseException:
        # Deliberately best-effort: any failure (including a manual interrupt) is
        # reported, and whatever was scored so far is still handed back.
        traceback.print_exc()
    return results

def cnn_3ms_test(model: TCNN, dl: DataLoader, hp: dict, n_samples = None):
    """Run `model` over a growing-then-sliding context of frames with a 48-sample hop and score it.

    Each signal is left-padded by ``frame_size - 48`` zeros, right-padded by the
    amount ``calc_windowing`` reports, and split into frames with
    ``get_all_frames``. For every frame index the model is given that frame
    plus up to ``hp["num_frames"] - 1`` preceding frames; the last 48 samples
    of the last output frame (and of the matching ``sample[1]`` frame) are passed
    to a FrameReconstructor. The reconstructed signals are scored with PESQ,
    STOI and segmental SNR.

    Args:
        model: network under test, called as ``model(x)`` with ``x`` viewed as
            ``(1, 1, context_frames, frame_size)``.
        dl: iterable of samples; ``sample[0]`` is the model input, ``sample[1]`` the metric reference.
        hp: hyper-parameter dict; ``"frame_size"`` and ``"num_frames"`` are read here.
        n_samples: maximum number of samples to evaluate; ``None`` means ``len(dl)``.

    Returns:
        A list with one dict per scored sample holding ``"pesq"``, ``"stoi"``,
        ``"ssnr"``, ``"time"`` (accumulated ``perf_counter_ns`` deltas around the
        model calls), ``"audio_sample_len"`` (length of the reconstructed output)
        and ``"frames"`` (number of model calls). Samples for which a metric
        raises ``NoUtterancesError`` are skipped. If anything else fails, the
        traceback is printed and the results gathered so far are returned.
    """
    frame_shift = 48
    results = []
    try:
        if n_samples is None:
            n_samples = len(dl)
        model.eval()
        frame_size = hp["frame_size"]
        max_context = hp["num_frames"]
        proc_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        clean_frame_constructor = FrameReconstructor(frame_size, frame_shift, 1)
        i = 0
        for sample in tqdm(dl):
            if i >= n_samples:
                break
            i += 1

            # Left-pad so the very first frame ends after exactly one hop of real signal.
            left_pad = abs(frame_shift - frame_size)
            x_pad = torch.nn.functional.pad(sample[0].squeeze(), (left_pad, 0), value=0.0)
            y_pad = torch.nn.functional.pad(sample[1].squeeze(), (left_pad, 0), value=0.0)
            _, _, pad = calc_windowing(x_pad.shape[0], frame_size, frame_shift)
            x_pad = torch.nn.functional.pad(x_pad, (0, pad), value=0.0)
            y_pad = torch.nn.functional.pad(y_pad, (0, pad), value=0.0)
            x_frames = get_all_frames(x_pad, frame_size, frame_shift)
            y_frames = get_all_frames(y_pad, frame_size, frame_shift)

            perf_time = 0
            j = 0
            # Every frame index is the newest frame of exactly one context window.
            # The window grows until it holds `max_context` frames and only then
            # starts sliding, so no frame is skipped at the transition and the
            # window never reaches past the last frame.
            for newest in range(x_frames.shape[1]):
                j += 1
                first_window = max(0, newest + 1 - max_context)
                windows_count = newest + 1 - first_window
                x = x_frames.narrow(1, first_window, windows_count)
                x = x.view(1, 1, x.shape[1], frame_size).to(device=device)
                y = y_frames.narrow(1, first_window, windows_count)
                y = y.view(1, 1, y.shape[1], frame_size).to(device=device)

                # Time the model call only, matching rnn_3ms_test, so "time" is
                # comparable between models.
                # NOTE(review): if `device` is CUDA the kernels may still be running
                # when the timer stops. Consider torch.cuda.synchronize() (and
                # torch.no_grad()) applied uniformly to every *_3ms_test function.
                _perf_time = time.perf_counter_ns()
                y_pred = model(x)
                perf_time += (time.perf_counter_ns() - _perf_time)

                # Only the newest hop of the newest frame is kept.
                proc_frame_constructor.add_presliced(y_pred[0, 0, -1, frame_size - frame_shift:frame_size].clone().detach())
                clean_frame_constructor.add_presliced(y[0, 0, -1, frame_size - frame_shift:frame_size].clone().detach())

            y_pred = proc_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            y = clean_frame_constructor.get_current_audio().squeeze().numpy(force=True)
            proc_frame_constructor.reset()
            clean_frame_constructor.reset()

            try:
                sq = calc_pesq(y, y_pred)
                si = calc_stoi(y, y_pred)
                ssnr = calc_snrseg(y, y_pred)
            except NoUtterancesError:
                continue
            results.append({
                "pesq": sq, "stoi": si, "ssnr": ssnr, "time": perf_time,
                "audio_sample_len": y_pred.shape[-1], "frames": j,
            })
    except BaseException:
        # Deliberately best-effort: any failure (including a manual interrupt) is
        # reported, and whatever was scored so far is still handed back.
        traceback.print_exc()
    return results

# Benchmark all four models with a 48-sample (3 ms) hop and store the per-sample results as JSON.
run_3ms_tests = True
if run_3ms_tests:
    snr_version = "high_snr"
    ds = SortedBatchDataset(get_sequential_wav_paths(f"data/mixed/test/{snr_version}"), get_sequential_wav_paths(f"data/speech_ordered/test/{snr_version}"),1)
    logging.disable(logging.DEBUG)

    # Create the output directory up front so the run cannot fail at the final write.
    os.makedirs("results", exist_ok=True)

    # Each stanza builds one model, loads its checkpoint onto `device`, runs the
    # matching test on the first 300 samples and frees the model before the next one.
    print("Testing GAN...")
    gan_test_hp = {"frame_size":16384, "frame_shift": 48, "batch_size":1}
    dl = DataLoader(ds)
    model = Generator().to(device=device)
    # This checkpoint is a dict; the generator weights sit under the "gen" key.
    model.load_state_dict(torch.load(r"saved_models/final/gan_29-04-2025--11-58-59/best_checkpoint_1.8107.pt", map_location=device)["gen"])
    gan_results = gan_3ms_test(model, dl, gan_test_hp,300)
    del model, dl
    torch.cuda.empty_cache()

    print("Testing CRN...")
    crn_test_hp = {"frame_size":96, "frame_shift": 48, "batch_size":1}
    dl = DataLoader(ds)
    model = ConvBSRU(frame_size=crn_test_hp["frame_size"], conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/crn_20-04-2025--12-49-08/best_model_1.7157.pt", map_location=device))
    crn_results = crn_3ms_test(model, dl, crn_test_hp,300)
    del model, dl
    torch.cuda.empty_cache()

    print("Testing RNN...")
    rnn_test_hp = {"frame_size":320, "frame_shift": 48, "batch_size":1}
    dl = DataLoader(ds)
    _rnn_hp = load_rnn_hp()
    model = RHRNet(_rnn_hp).to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/rnn_21-04-2025--01-00-43/best_model_2.0291.pt", map_location=device))
    rnn_results = rnn_3ms_test(model, dl, rnn_test_hp,300)
    del model, dl
    torch.cuda.empty_cache()

    print("Testing CNN...")
    cnn_test_hp = {"frame_size":320, "frame_shift": 48,"num_frames": 300, "batch_size":1}
    dl = DataLoader(ds)
    model = TCNN().to(device=device)
    model.load_state_dict(torch.load(r"saved_models/final/cnn_08-05-2025--12-02-34/best_model_1.9495.pt", map_location=device))
    cnn_results = cnn_3ms_test(model, dl, cnn_test_hp,300)
    del model, dl
    torch.cuda.empty_cache()

    all_res = {"gan":gan_results, "crn":crn_results, "rnn":rnn_results, "cnn":cnn_results}
    # The rnn/cnn tests shown here print a traceback and return what they have on
    # failure, so report how many samples each model scored rather than dumping the lists.
    for _name, _res in all_res.items():
        print(f"{_name}: {len(_res)} samples evaluated")
    with open(f"results/{snr_version}_3ms_{datetime_string()}.json","w") as file:
        json.dump(all_res, file)

# Testing

In [17]:
# Load a stored benchmark run and print per-model summary statistics.
with open("results/high_snr_10ms_13-05-2025--20-56-08.json","r") as file:
    results = json.load(file)

# Result entries for "rnn" may lack the "frames" count; borrow it from the "cnn"
# entry at the same position, but never overwrite a value the file already has.
# NOTE(review): this assumes both lists cover the same samples in the same order — confirm.
for i, r in enumerate(results["rnn"]):
    if "frames" not in r:
        r["frames"] = results["cnn"][i]["frames"]

for mdl, stats in results.items():
    print(mdl)
    pesqs = [x["pesq"] for x in stats]
    stois = [x["stoi"] for x in stats]
    ssnrs = [x["ssnr"] for x in stats]
    perf_time = [x["time"] for x in stats]
    audio_sample_len = [x["audio_sample_len"] for x in stats]
    frames = [x["frames"] for x in stats]
    # Mean time per processed frame for each sample. The loop variable must not be
    # called `time`: at notebook scope that would replace the imported `time` module
    # and break every later call to time.perf_counter_ns().
    proc_time = [elapsed_ns / frame_count for elapsed_ns, frame_count in zip(perf_time, frames)]

    # "time" is stored in nanoseconds: MS_TO_NS * 1000 converts to seconds, MS_TO_NS to milliseconds.
    print(f"Audio Duration - {np.sum(audio_sample_len)//SAMPLE_RATE} seconds")
    print(f"Time Spent processing - {np.sum(perf_time)// (MS_TO_NS * 1000)} seconds")
    print(f"Average frame process time - \\num{{{np.mean(proc_time)/MS_TO_NS} \\pm {np.std(proc_time)/MS_TO_NS}}}")
    print(f"PESQ - \\num{{{np.mean(pesqs)} \\pm {np.std(pesqs)}}}")
    print(f"STOI - \\num{{{np.mean(stois)} \\pm {np.std(stois)}}}")
    print(f"SSNR - \\num{{{np.mean(ssnrs)} \\pm {np.std(ssnrs)}}}")

print(results["gan"][0].keys())



gan
Audio Duration - 12532 seconds
Time Spent processing - 761.0 seconds
Average frame process time - \num{0.6069721550853251 \pm 0.0087605113700968}
PESQ - \num{1.580843438744545 \pm 0.20209189055903545}
STOI - \num{0.909417623193533 \pm 0.040520907140753945}
SSNR - \num{6.55778811408832 \pm 1.5346181403630277}
crn
Audio Duration - 12533 seconds
Time Spent processing - 1173500.0 seconds
Average frame process time - \num{360.8187034415579 \pm 234.49055370293144}
PESQ - \num{1.8796547166109085 \pm 0.2874725416674613}
STOI - \num{0.9247800024814373 \pm 0.038813890293902205}
SSNR - \num{7.178437906680332 \pm 1.8109510368759507}
rnn
Audio Duration - 12540 seconds
Time Spent processing - 2135.0 seconds
Average frame process time - \num{1.7036576129300804 \pm 0.01170476001600824}
PESQ - \num{2.116352467536926 \pm 0.2904823246083804}
STOI - \num{0.9359856942485169 \pm 0.036673957559702396}
SSNR - \num{8.364823564455847 \pm 1.8781794491533674}
cnn
Audio Duration - 12530 seconds
Time Spent proc