In [None]:
#!conda env update --file environment.yml --prune

In [None]:
import os, random, glob, logging, ntpath, math, time, sys
from IPython import display
from IPython.display import Audio
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# pd.options.display.max_seq_items = 2000
pd.set_option("display.max_colwidth",None)

import os, random, glob, logging, ntpath, math, time, sys, datetime
logging.basicConfig()
logger=logging.getLogger("dbg")
logging.disable(logging.NOTSET)
logger.setLevel(logging.DEBUG)
perf_logger=logging.getLogger("perf")
perf_logger.setLevel(logging.DEBUG)
# logging.disable(logging.DEBUG)

import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)
logger.debug(torch.cuda.get_device_name(device))
logger.debug(torchaudio.list_audio_backends())

from ignite.engine import Engine, Events, EventEnum
from ignite.metrics import Loss, Metric, RunningAverage
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.exceptions import NotComputableError
from ignite.handlers.tqdm_logger import ProgressBar

In [None]:
BATCH_SIZE = 8
SHUFFLE = False
FRAME_SHIFT = 40
MAXIMUM_SAMPLE_NUM_OF_FRAMES = 640000

RUN_GAN = False
RUN_CRN = True
RUN_RNN = False
RUN_CNN = False

# Utility

In [None]:
from IPython.core.magic import register_cell_magic
from IPython import get_ipython

@register_cell_magic
def skip(line, cell):
    return

@register_cell_magic
def skip_if(line, cell):
    if eval(line):
        return
    get_ipython().run_cell(cell)

In [None]:
import pystoi
import pesq

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        # speech = torch.tensor(speech, dtype=torch.float64, device=speech.device)
        speech = speech.to(torch.float64,non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        # noise = torch.tensor(noise, dtype=torch.float64, device=noise.device)
        noise = noise.to(torch.float64,non_blocking=True)
    if not(type(snr) is torch.Tensor):
        snr = torch.tensor([snr])
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    out = audioF.add_noise(speech, noise, snr).to(dtype=torch.float)
    return out

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    return pesq.pesq(ref=speech, deg=processed, fs=16000)

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    return pystoi.stoi(x=speech, y=processed, fs_sig=16000)

def ns_to_sec(ns: int):
    return ns/1000000000.0

def datetime_string():
    return datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")

def plot_waveform(waveform, sample_rate=16000):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")

def write_fstring_file(model_name: str, format_string: str, **args):
    with open(f"saved_models/{model_name}_{datetime_string()}.txt") as f:
        f.write(format_string.format(**args))


# def stitch_audio(frames_func, frame_size, frame_shift):
#     *batches, at_end = frames_func()

#     # audio[mix, clean]
#     audio = [torch.zeros((b.shape[0],b.shape[1],640000),dtype=torch.float) for b in batches if len(b.shape)==3]
#     pos=0
#     end=frame_size
#     for batch_i, batch in enumerate(batches):
#         audio[batch_i][:,:,:end] = batch[:,:,:]

#     frame_slice_start = frame_size - frame_shift
#     while not at_end:
#         pos = end
#         end += frame_shift
#         *batches, at_end = frames_func()
#         for batch_i, batch in enumerate((batches)):
#             audio[batch_i][:,0,pos:end] = batch[:,0,frame_slice_start:]

#             if at_end:
#                 print(audio[batch_i].shape)
#                 audio[batch_i] = torch.tensor(audio[batch_i][:,:,:end]) 
        
#     for i in range(BATCH_SIZE):
#         print(audio[0][i][0])

#     return audio

## Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

# ldrnd = random.Random(42)   #   Used for noise loading 

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedBatchDataset(Dataset):
    def __init__(self, mixed: list, clean: list, batch_size: int = 1):
        self.batch_size = batch_size
        self.mixed = mixed
        self.clean = clean
        # if root_dir==None:
        #     root_dir = os.getcwd()+"/data"
        # self.root_dir = root_dir

    def __len__(self):
        return math.floor(len(self.mixed)/self.batch_size)
    
    def __getitem__(self,idx):
        mixeds = []
        cleans = []
        max_len = 0
        for i in range(self.batch_size):
            if idx + i > len(self.mixed): continue
            mixed_wave, _ = torchaudio.load(self.mixed[idx*self.batch_size+i])
            clean_wave, _ = torchaudio.load(self.clean[idx*self.batch_size+i])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert(mixed_wave.shape[0]==clean_wave.shape[0])
            if i==0:
                max_len = mixed_wave.shape[0]
            else:
                mixed_wave = torch.nn.functional.pad(mixed_wave,(0,max_len-(mixed_wave.shape[0])),value=0.0)
                clean_wave = torch.nn.functional.pad(clean_wave,(0,max_len-(clean_wave.shape[0])),value=0.0)
            # logger.debug(mixed_wave)
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)
        mixeds = np.asarray(mixeds)
        cleans = np.asarray(cleans)
        return torch.tensor(mixeds).unsqueeze(1), torch.tensor(cleans).unsqueeze(1)

    def split(self, val_pct, seed=None):
        rnd = random.Random()
        if seed is not None:
            rnd.seed(seed)
        this_len = len(self)
        val_batches = math.floor(this_len*val_pct)
        val_indices = sorted(rnd.sample(range(this_len), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(this_len):
            if i in val_indices:
                for x in range(self.batch_size):
                    val_mixed.append(self.mixed[i*self.batch_size+x])
                    val_clean.append(self.clean[i*self.batch_size+x])
            else:
                for x in range(self.batch_size):
                    train_mixed.append(self.mixed[i*self.batch_size+x])
                    train_clean.append(self.clean[i*self.batch_size+x])
        
        return SortedBatchDataset(train_mixed,train_clean,self.batch_size), SortedBatchDataset(val_mixed,val_clean,self.batch_size)

class FrameLoaderEvents(EventEnum):
    END_OF_BATCH = "end_of_batch"

class FrameLoader():
    '''Takes a dataloader, frame size and frame shift. It can then be iterated over to produce frames.\n
    Provides padding when sample length would be exceeded.\n
    Returns (mix, clean, has_batch_ended)'''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, engine: Engine | None = None, batch_size=BATCH_SIZE):
        self.dl = dl
        self.dl_iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_size = batch_size
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        self.at_end = True
        self.engine = engine
    def __iter__(self):
        self.dl_iter = iter(self.dl)
        return self
    def __next__(self):
        if self.at_end:
            batches: tuple[torch.Tensor,torch.Tensor] = next(self.dl_iter)
            self.batch_mixed = batches[0].squeeze_(0)
            self.batch_clean = batches[1].squeeze_(0)
            self.frame_position = 0
            self.at_end = False
            # logger.debug(f"mixed batch shape:{self.batch_mixed.shape} | clean batch shape:{self.batch_clean.shape}")
            # mix_maxes = [torch.max(s[0]) for s in self.batch_mixed]
            # clean_maxes = [torch.max(s[0]) for s in self.batch_clean]
            # logger.debug(f"mix_maxes:{str(mix_maxes)} | clean_maxes:{str(clean_maxes)}")
        
        frame_end = self.frame_position + self.frame_size
        frames = []
        for batch_i, batch in enumerate([self.batch_mixed, self.batch_clean]):
            shp = batch.shape
            frame: torch.Tensor
            if frame_end >= shp[2]:
                if self.engine is not None and batch_i==0: 
                    self.engine.fire_event(FrameLoaderEvents.END_OF_BATCH)
                self.at_end = True
                if frame_end != shp[2]:
                    diff = frame_end - shp[2]
                    # Pad batch until aligned with frame_end
                    frame = torch.zeros((self.batch_size ,1,self.frame_size),dtype=torch.float32)
                    frame[:, 0, 0:self.frame_size - diff] = batch[:, 0, self.frame_position:shp[2]]
                else:
                    frame = torch.zeros((self.batch_size ,1,self.frame_size),dtype=torch.float32)
                    frame[:,0,:] = batch[:,0,self.frame_position:frame_end]
                    frame = torch.tensor(batch[:,0,self.frame_position:frame_end])
            else:
                frame = torch.zeros((self.batch_size ,1,self.frame_size),dtype=torch.float32)
                frame[:,0,:] = batch[:,0,self.frame_position:frame_end]
            frames.append(frame)

        self.frame_position += self.frame_shift
        # logger.debug(f"FrameLoader: frame_position[{self.frame_position}] - frame_end[{frame_end}]")
        # perf_logger.debug(f"Time to load frame at ({self.frame_position}): {time.perf_counter() - start}s")
        if frames[0].shape[2] != self.frame_size:
            logger.debug(frames[0].shape[2])
            logger.debug("FrameLoader issue")
        
        return frames[0], frames[1], self.at_end

class FrameReconstructor():
    '''Constructs a batch of audio samples by continuously adding (batches of) frames to the end of a buffer, excluding overlapping sections.\n
    Use `add_frame()` to return the constructed samples, up to the last batch of frames added.
    '''
    def __init__(self, frame_size: int, frame_shift: int,target_shape: torch.Size=(BATCH_SIZE, 1, MAXIMUM_SAMPLE_NUM_OF_FRAMES)):
        self.audio: torch.Tensor = torch.zeros(target_shape,dtype=torch.float)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.pos: int = 0
        self.end: int = frame_size
        self.frame_slice_start: int = 0
        self.at_end = False
    
    def add_frame(self, batch: torch.Tensor, _at_end = False):
        self.audio[:,0,self.pos:self.end] = batch[:,0,self.frame_slice_start:]

        self.pos = self.end
        self.end += self.frame_shift
        self.frame_slice_start = self.frame_size - self.frame_shift

        # if _at_end: 
        #     return True
        # return False

        #   Remove if _at_end ends up doing something
        return _at_end
    
    def get_current_audio(self):
        out = torch.tensor(self.audio[:,0,:self.end - self.frame_shift]).unsqueeze_(1)
        return out

    def reset(self):
        self.audio = torch.zeros(self.audio.shape,dtype=torch.float)
        self.pos = 0
        self.end = self.frame_size
        self.frame_slice_start = 0
        self.at_end = False
    

_dataset = SortedBatchDataset(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=BATCH_SIZE)
train_dataset, val_dataset = _dataset.split(0.2,42)
del _dataset
base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)
base_val_dataloader = DataLoader(val_dataset)

In [None]:
print(len(train_dataset))

print(len(val_dataset))

# Models

In [None]:
criterion = nn.MSELoss()

class PESQMetric(Metric):
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        self.stitch_keys=stitch_keys
        self.running_total=0.0
        self.num=0
        super().__init__(output_transform, device)
    @reinit__is_reduced
    def reset(self):
        self.running_total=0.0
        self.num=0
        super().reset()
    @reinit__is_reduced
    def update(self, output):
        if len(output)<=2 or "stitch_proc" not in output[2]: return
        y_pred: np.ndarray
        y: np.ndarray
        y_pred, y = output[2][self.stitch_keys[0]].cpu().numpy(), output[2][self.stitch_keys[1]].cpu().numpy()
        for i in range(BATCH_SIZE):
            self.running_total += calc_pesq(y[i][0], y_pred[i][0])
            self.num += 1
            # print("Processed:")
            # display.display(Audio(y_pred[i][0], rate=16000))
            # print("Clean:")
            # display.display(Audio(y[i][0], rate=16000))
        
    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        if self.num == 0:
            raise NotComputableError("PESQ Metric must have one complete sample before computing")
        return self.running_total / self.num

class STOIMetric(Metric):
    def __init__(self, stitch_keys=("stitch_proc","stitch_clean"), output_transform = lambda x: x, device=device):
        self.stitch_keys=stitch_keys
        self.running_total=0.0
        self.num=0
        super().__init__(output_transform, device)
    @reinit__is_reduced
    def reset(self):
        self.running_total=0.0
        self.num=0
        super().reset()
    @reinit__is_reduced
    def update(self, output):
        if len(output)<=2 or "stitch_proc" not in output[2]: return
        y_pred: np.ndarray
        y: np.ndarray
        y_pred, y = output[2][self.stitch_keys[0]].cpu().numpy(), output[2][self.stitch_keys[1]].cpu().numpy()
        for i in range(BATCH_SIZE):
            self.running_total += calc_stoi(y[i][0], y_pred[i][0])
            self.num += 1
    @sync_all_reduce("num","running_total:SUM")
    def compute(self):
        if self.num == 0:
            raise NotComputableError("STOI Metric must have one complete sample before computing")
        return self.running_total / self.num


train_metrics: dict[str, Metric] = {
    "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
}
val_metrics: dict[str, Metric] = {
    "loss": Loss(criterion, output_transform=lambda x: (x[0],x[1])),
    "pesq": PESQMetric(),
    "stoi": STOIMetric()
}
def register_custom_events(eng: Engine):
    eng.register_events(*FrameLoaderEvents)

def log_trainer_loss(eng: Engine):
    iterations = eng.state.iteration % eng.state.iteration_ceiling
    print(f"Epoch[{eng.state.epoch}], Iter[{iterations}] Loss: {eng.state.output}")

def log_custom(eng: Engine, **kwargs):
    full_dict = {**eng.state_dict(), "epoch": eng.state.epoch, **kwargs}
    fmt_string: str = kwargs["template"]
    print(fmt_string.format(**full_dict))

def log_eval_results(eng: Engine, **kwargs):
    prefix = kwargs.get("prefix","")
    val_evaluator: Engine = kwargs.get("val_evaluator",None)
    val_frame_loader: FrameLoader = kwargs.get("val_frame_loader",None)
    if val_evaluator is None:
        raise TypeError("log_eval_results must be passed the argument `val_evaluator` of type `Engine`")
    if val_frame_loader is None:
        raise TypeError("log_eval_results must be passed the argument `val_frame_loader` of type `FrameLoader`")
    
    val_evaluator.run(val_frame_loader)
    metrics = val_evaluator.state.metrics
    print(f"{prefix}Epoch[{eng.state.epoch}] | PESQ:[{metrics['pesq']:.2f}] | STOI:[{metrics['stoi']:.2f}] | Loss:[{metrics['loss']}]")

def set_engine_custom_keys(eng: Engine):
    eng.state_dict_user_keys.append("iteration_ceiling")
    eng.state.iteration_ceiling = sys.maxsize

def set_iteration_ceiling(eng: Engine, *args):
    if len(args)==1:
        eng.state.iteration_ceiling = args[0]
    else:
        eng.state.iteration_ceiling = eng.state.iteration

## SEGAN

In [None]:
if RUN_GAN:
    from models.segan import Generator, Discriminator



## WaveCRN

In [None]:
CRN_FRAME_SIZE = 96
CRN_FRAME_SHIFT = 40
CRN_LR = 1.0e-4

CRN_PREP = True
CRN_TRAIN = True
CRN_LOAD = False
CRN_SAVE = True

In [14]:
if CRN_PREP:
    try:
        logging.disable(logging.DEBUG)
        from models.wavecrn import ConvBSRU
        logger.debug("wavecrn loaded")

        crn_model: ConvBSRU
        crn_optimizer: torch.optim.Adam
        if 'crn_model' in locals(): del crn_model
        if 'crn_optimizer' in locals(): del crn_optimizer

        torch.cuda.empty_cache()

        crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
        if not CRN_TRAIN:
            crn_model.load_state_dict(torch.load("saved_models/crn.pt", weights_only=True))

        crn_optimizer = torch.optim.Adam(crn_model.parameters(),lr=CRN_LR)

        pf_train_totals = [0,0]                                                     ###
        pf_train_num_loops = 0                                                      ###
        def crn_train_step(engine, batch):
            global pf_train_totals, pf_train_num_loops
            pf_train_forward = time.perf_counter_ns()                               ###
            crn_model.train()
            crn_optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = crn_model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)       ###
            pf_train_back = time.perf_counter_ns()                                  ###
            loss = criterion(y_pred, y)
            loss.backward()
            crn_optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)          ###
            pf_train_num_loops += 1                                                 ###
            return loss.item()

        crn_trainer = Engine(crn_train_step)
        register_custom_events(crn_trainer)
        RunningAverage(output_transform=lambda x: x).attach(crn_trainer,'loss')

        pbar = ProgressBar()
        pbar.attach(crn_trainer,['loss'])

        crn_trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        crn_trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)
        # crn_trainer.add_event_handler(Events.ITERATION_COMPLETED(every=200),log_trainer_loss)
        # crn_trainer.add_event_handler(FrameLoaderEvents.END_OF_BATCH,log_custom,template="Batch complete | Epoch: {epoch}, Iteration: {iteration}")

        crn_train_dataloader = FrameLoader(base_train_dataloader, CRN_FRAME_SIZE, CRN_FRAME_SHIFT, crn_trainer)

        proc_frame_constructor = FrameReconstructor(CRN_FRAME_SIZE, CRN_FRAME_SHIFT)
        clean_frame_constructor = FrameReconstructor(CRN_FRAME_SIZE, CRN_FRAME_SHIFT)

        pf_eval_total = 0                                                           ###
        pf_eval_num_loops = 0                                                       ###
        def crn_val_step(engine, batch):
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()                                ###
            crn_model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = crn_model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)         ###
                pf_eval_num_loops += 1                                              ###
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        crn_val_evaluator = Engine(crn_val_step)

        # for name, metric in train_metrics.items():
        #     metric.attach(crn_train_evaluator, name)
        for name, metric in val_metrics.items():
            metric.attach(crn_val_evaluator, name)

        crn_val_dataloader = FrameLoader(base_val_dataloader, CRN_FRAME_SIZE, CRN_FRAME_SHIFT)
        crn_trainer.add_event_handler(Events.EPOCH_COMPLETED,log_eval_results,val_evaluator=crn_val_evaluator,val_frame_loader=crn_val_dataloader)

        if CRN_TRAIN:
            epchs=10
            crn_trainer.run(crn_train_dataloader, max_epochs=epchs)
            a=pf_train_totals[0] / float(pf_train_num_loops)                        ###
            b=pf_train_totals[1] / float(pf_train_num_loops)                        ###
            c=pf_eval_total / float(pf_eval_num_loops)                              ###
            print(f"train forward:{a:.20f} | backprop:{b:.20f} | eval forward:{c:.20f}")
        
        if CRN_SAVE:
            torch.save(crn_model.state_dict(),f"saved_models/crn_{datetime_string()}.pt")
            write_fstring_file("crn","lr:{lr},epochs:{epochs}",lr=CRN_LR,epochs=epchs)


    finally:
        logging.disable(logging.NOTSET)

ERROR:ignite.engine.engine.Engine:Engine run is terminating due to exception: 


KeyboardInterrupt: 

In [None]:
a=ns_to_sec(pf_train_totals[0] / float(pf_train_num_loops))                         ###
b=ns_to_sec(pf_train_totals[1] / float(pf_train_num_loops))                         ###
c=ns_to_sec(pf_eval_total / float(pf_train_num_loops/4))                            ###
print(f"train forward:{a:.20f} | backprop:{b:.20f} | eval forward:{c:.20f}")

In [15]:
if 'crn_model' in locals(): torch.save(crn_model.state_dict(),f"saved_models/crn_{datetime_string()}.pt")

## RHR-Net

In [None]:
RNN_FRAME_SIZE = 96
RNN_FRAME_SHIFT = 40
RNN_LR = 1.0e-5

RNN_PREP = True
RNN_TRAIN = True
RNN_LOAD = False
RNN_SAVE = False

In [None]:
if RNN_PREP:
    try:
        logging.disable(logging.DEBUG)
        logger.debug("rhrnet loaded")

        import yaml
        from models.rhrnetdir.Arg_Parser import Recursive_Parse
        from models.rhrnet import RHRNet
        rnn_hp = Recursive_Parse(yaml.load(
            open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8'),
            Loader=yaml.Loader
            ))  
        rnn_model: RHRNet
        rnn_optimizer: torch.optim.Adam
        if 'rnn_model' in locals(): del rnn_model
        if 'rnn_optimizer' in locals(): del rnn_optimizer

        torch.cuda.empty_cache()

        rnn_model = RHRNet(rnn_hp).to(device=device)
        if RNN_LOAD:
            rnn_model.load_state_dict(torch.load("saved_models/rnn.pt", weights_only=True))

        rnn_optimizer = torch.optim.Adam(rnn_model.parameters(),lr=RNN_LR)

        pf_train_totals = [0,0]
        pf_train_num_loops = 0
        def rnn_train_step(engine, batch):
            global pf_train_totals, pf_train_num_loops
            pf_train_forward = time.perf_counter_ns()
            rnn_model.train()
            rnn_optimizer.zero_grad()
            x, y = batch[0].squeeze_(1).to(device), batch[1].squeeze_(1).to(device)
            y_pred = rnn_model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)
            pf_train_back = time.perf_counter_ns()
            loss = criterion(y_pred, y)
            loss.backward()
            rnn_optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)
            pf_train_num_loops += 1
            return loss.item()

        rnn_trainer = Engine(rnn_train_step)
        register_custom_events(rnn_trainer)
        RunningAverage(output_transform=lambda x: x).attach(rnn_trainer,'loss')

        pbar = ProgressBar()
        pbar.attach(rnn_trainer,['loss'])

        rnn_trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        rnn_trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)
        # rnn_trainer.add_event_handler(Events.ITERATION_COMPLETED(every=200),log_trainer_loss)
        # rnn_trainer.add_event_handler(FrameLoaderEvents.END_OF_BATCH,log_custom,template="Batch complete | Epoch: {epoch}, Iteration: {iteration}")

        rnn_train_dataloader = FrameLoader(base_train_dataloader, RNN_FRAME_SIZE, RNN_FRAME_SHIFT, rnn_trainer)

        proc_frame_constructor = FrameReconstructor(RNN_FRAME_SIZE, RNN_FRAME_SHIFT)
        clean_frame_constructor = FrameReconstructor(RNN_FRAME_SIZE, RNN_FRAME_SHIFT)

        pf_eval_total = 0
        pf_eval_num_loops = 0
        def rnn_val_step(engine, batch):
            global pf_eval_total, pf_eval_num_loops
            pf_eval_forward = time.perf_counter_ns()
            rnn_model.eval()
            with torch.no_grad():
                x, y = batch[0].squeeze_(1).to(device), batch[1].squeeze_(1).to(device)
                y_pred = rnn_model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                pf_eval_num_loops += 1
                proc_frame_constructor.add_frame(y_pred.unsqueeze_(1))
                clean_frame_constructor.add_frame(y.unsqueeze_(1))
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        rnn_val_evaluator = Engine(rnn_val_step)

        # for name, metric in train_metrics.items():
        #     metric.attach(rnn_train_evaluator, name)
        for name, metric in val_metrics.items():
            metric.attach(rnn_val_evaluator, name)

        rnn_val_dataloader = FrameLoader(base_val_dataloader, RNN_FRAME_SIZE, RNN_FRAME_SHIFT)
        rnn_trainer.add_event_handler(Events.EPOCH_COMPLETED,log_eval_results,val_evaluator=rnn_val_evaluator,val_frame_loader=rnn_val_dataloader)

        if RNN_TRAIN:
            rnn_trainer.run(rnn_train_dataloader, max_epochs=1)
            a=pf_train_totals[0] / float(pf_train_num_loops)
            b=pf_train_totals[1] / float(pf_train_num_loops)
            c=pf_eval_total / float(pf_eval_num_loops)
            print(f"train forward:{a:.20f} | backprop:{b:.20f} | eval forward:{c:.20f}")

        if RNN_SAVE:
            torch.save(rnn_model.state_dict(),f"saved_models/rnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")


    finally:
        logging.disable(logging.NOTSET)

In [None]:
# if 'rnn_model' in locals(): torch.save(rnn_model.state_dict(),f"saved_models/rnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

## Wave-U-Net

In [None]:
CNN_FRAME_SIZE = 16153
CNN_OUT_FRAME_SIZE = 16009
CNN_FRAME_SHIFT = CNN_FRAME_SIZE / 4
CNN_LR = 1.0e-4

CNN_PREP = True
CNN_TRAIN = False
CNN_LOAD= False
CNN_SAVE = False

In [None]:
if CNN_PREP:
    try:
        logging.disable(logging.DEBUG)
        from models.waveunet import Waveunet
 
        cnn_model: RHRNet
        cnn_optimizer: torch.optim.Adam
        if 'cnn_model' in locals(): del cnn_model
        if 'cnn_optimizer' in locals(): del cnn_optimizer

        torch.cuda.empty_cache()
        from types import SimpleNamespace
        args = SimpleNamespace(
            features = 32, 
            instruments =  [], 
            res =  "fixed", 
            separate =  0, 
            channels =  1, 
            kernel_size =  5, 
            levels =  3, 
            depth = 1, 
            feature_growth =  "double",
            strides = 4,
            conv_type = "gn",
            sr =  16000,
            output_size =  1.0
        )
        num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
                   [args.features*2**i for i in range(0, args.levels)]
        target_outputs = int(args.output_size * args.sr)
        cnn_model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
                        target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                        conv_type=args.conv_type, res=args.res, separate=args.separate).to(device=device)
        if CNN_LOAD:
            cnn_model.load_state_dict(torch.load("saved_models/cnn.pt", weights_only=True))

        cnn_optimizer = torch.optim.Adam(cnn_model.parameters(),lr=CNN_LR)

        pf_train_totals = [0,0]
        pf_train_num_loops = 0
        def cnn_train_step(engine, batch):
            pf_train_forward = time.perf_counter_ns()
            cnn_model.train()
            cnn_optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            y_pred = cnn_model(x)
            pf_train_totals[0] += (time.perf_counter_ns() - pf_train_forward)
            pf_train_back = time.perf_counter_ns()
            loss = criterion(y_pred, y)
            loss.backward()
            cnn_optimizer.step()
            pf_train_totals[1] += (time.perf_counter_ns() - pf_train_back)
            pf_train_num_loops += 1
            return loss.item()

        cnn_trainer = Engine(cnn_train_step)
        register_custom_events(cnn_trainer)
        RunningAverage(output_transform=lambda x: x).attach(cnn_trainer,'loss')

        pbar = ProgressBar()
        pbar.attach(cnn_trainer,['loss'])

        cnn_trainer.add_event_handler(Events.STARTED, set_engine_custom_keys)
        cnn_trainer.add_event_handler(Events.EPOCH_COMPLETED(once=1),set_iteration_ceiling)
        # cnn_trainer.add_event_handler(Events.ITERATION_COMPLETED(every=200),log_trainer_loss)
        # cnn_trainer.add_event_handler(FrameLoaderEvents.END_OF_BATCH,log_custom,template="Batch complete | Epoch: {epoch}, Iteration: {iteration}")

        cnn_train_dataloader = FrameLoader(base_train_dataloader, CNN_FRAME_SIZE, CNN_FRAME_SHIFT, cnn_trainer)

        proc_frame_constructor = FrameReconstructor(CNN_OUT_FRAME_SIZE, CNN_FRAME_SHIFT)
        clean_frame_constructor = FrameReconstructor(CNN_OUT_FRAME_SIZE, CNN_FRAME_SHIFT)

        pf_eval_total = 0
        pf_eval_num_loops = 0
        def cnn_val_step(engine, batch):
            pf_eval_forward = time.perf_counter_ns()
            cnn_model.eval()
            with torch.no_grad():
                x, y = batch[0].to(device), batch[1].to(device)
                y_pred = cnn_model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                proc_frame_constructor.add_frame(y_pred)
                clean_frame_constructor.add_frame(y)
                if batch[2]:    #   Frame fully constructed
                    y_pred_stitch = proc_frame_constructor.get_current_audio()
                    y_stitch = clean_frame_constructor.get_current_audio()
                    proc_frame_constructor.reset()
                    clean_frame_constructor.reset()
                    return y_pred, y, {"stitch_proc": y_pred_stitch, "stitch_clean": y_stitch}

                return y_pred, y

        cnn_val_evaluator = Engine(cnn_val_step)

        # for name, metric in train_metrics.items():
        #     metric.attach(cnn_train_evaluator, name)
        for name, metric in val_metrics.items():
            metric.attach(cnn_val_evaluator, name)

        cnn_val_dataloader = FrameLoader(base_val_dataloader, CNN_FRAME_SIZE, CNN_FRAME_SHIFT)
        cnn_trainer.add_event_handler(Events.EPOCH_COMPLETED,log_eval_results,val_evaluator=cnn_val_evaluator,val_frame_loader=cnn_val_dataloader)

        if CNN_TRAIN:
            cnn_trainer.run(cnn_train_dataloader, max_epochs=10)

        if CNN_SAVE:
            torch.save(cnn_model.state_dict(),f"saved_models/cnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

    finally:
        logging.disable(logging.NOTSET)

In [None]:
# if 'cnn_model' in locals(): torch.save(cnn_model.state_dict(),f"saved_models/cnn_{datetime.datetime.now().strftime("%d-%m-%Y--%H-%M-%S")}.pt")

# Running Models 

## Model Running Functions 

In [None]:
class FakeDataset(Dataset):
    def __init__(self, l_waveforms: list, r_waveforms:list):
        self.l_waveforms = l_waveforms
        self.r_waveforms = r_waveforms
        super().__init__()
    def __len__(self):
        return len(self.l_waveforms)
    def __getitem__(self,idx):
        # return self.l_waveforms[idx], self.r_waveforms[idx]
        # print(torch.tensor(self.l_waveforms[idx]).unsqueeze_(0))
        # print(torch.tensor(self.l_waveforms[idx]).unsqueeze_(0).shape,flush=True)

        return torch.tensor(self.l_waveforms[idx]).unsqueeze_(0), torch.tensor(self.r_waveforms[idx]).unsqueeze_(0)

def evaluate_e2e_one_sample(model_dict: dict):
    chosen_sample = ntpath.basename(random.choice(glob.glob("data/speech_ordered/train/*.wav")))
    clean_sample,_ = torchaudio.load("data/speech_ordered/train/" + chosen_sample)
    mixed_sample,_ = torchaudio.load("data/mixed/train/" + chosen_sample)
    ds = FakeDataset([mixed_sample],[clean_sample])
    dl = DataLoader(ds)
    out = {}
    for name in model_dict.keys():
        out[name] = {}
        out[name]["model"] = model_dict[name]
        match name:
            case "cnn":
                out[name]["loader"] = FrameLoader(dl,CNN_FRAME_SIZE,CNN_FRAME_SHIFT, batch_size=1)
                out[name]["processed_constructor"] = FrameReconstructor(CNN_OUT_FRAME_SIZE,CNN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["clean_constructor"] = FrameReconstructor(CNN_OUT_FRAME_SIZE,CNN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["in_transform"] = lambda x: x
                out[name]["out_transform"] = lambda x: x
            case "rnn":
                out[name]["loader"] = FrameLoader(dl,RNN_FRAME_SIZE,RNN_FRAME_SHIFT, batch_size=1)
                out[name]["processed_constructor"] = FrameReconstructor(RNN_FRAME_SIZE,RNN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["clean_constructor"] = FrameReconstructor(RNN_FRAME_SIZE,RNN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["in_transform"] = lambda x: x.squeeze_(1)
                out[name]["out_transform"] = lambda x: x.unsqueeze_(1)
            case "crn":
                out[name]["loader"] = FrameLoader(dl,CRN_FRAME_SIZE,CRN_FRAME_SHIFT, batch_size=1)
                out[name]["processed_constructor"] = FrameReconstructor(CRN_FRAME_SIZE,CRN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["clean_constructor"] = FrameReconstructor(CRN_FRAME_SIZE,CRN_FRAME_SHIFT, (1,1,MAXIMUM_SAMPLE_NUM_OF_FRAMES))
                out[name]["in_transform"] = lambda x: x
                out[name]["out_transform"] = lambda x: x
            # case "gan":
            #     out[name]["loader"] = FrameLoader(dl,GAN_FRAME_SIZE,GAN_FRAME_SHIFT)
            case _:
                pass
        out[name]["perf"] = {"e2e_time":0, "avg_forward":0}
    # print("shape__:" + str(clean_sample.shape))
    out["num_frames"] = clean_sample.shape[1]
    
    for name in model_dict.keys():
        # print(name)
        # print(out[name],flush=True)
        t_in: function = out[name]["in_transform"]
        t_out: function = out[name]["out_transform"]
        at_end = False
        model = out[name]["model"]
        data = iter(out[name]["loader"])
        pf_eval_total = 0
        n_loops = 0
        pf_eval_e2e = time.perf_counter_ns()
        while not at_end:
            pf_eval_forward = time.perf_counter_ns()
            model.eval()
            with torch.no_grad():
                x, y, at_end = next(data)
                x = t_in(x.to(device=device))
                y = t_in(y.to(device=device))
                y_pred = model(x)
                pf_eval_total += (time.perf_counter_ns() - pf_eval_forward)
                n_loops += 1
                try:
                    out[name]["processed_constructor"].add_frame(t_out(y_pred))
                    out[name]["clean_constructor"].add_frame(t_out(y))
                except RuntimeError as e:
                    print(out[name]["processed_constructor"].audio.shape)
                    raise RuntimeError(e.args)
                

                if at_end:    #   Frame fully constructed
                    y_pred_stitch = out[name]["processed_constructor"].get_current_audio()
                    y_stitch = out[name]["clean_constructor"].get_current_audio()
                    print(y_pred_stitch.shape)
                    print(y_stitch.shape)
                    display.display(Audio(y_pred_stitch,rate=16000))
                    display.display(Audio(y_stitch,rate=16000))
                    
        
        pf_eval_e2e = time.perf_counter_ns() - pf_eval_e2e
        out[name]["perf"]["e2e_time"] = pf_eval_e2e
        out[name]["perf"]["avg_forward"] = pf_eval_total / float(n_loops)
    
    return out

In [20]:
crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
crn_model.load_state_dict(torch.load("saved_models/crn.pt", weights_only=True))
models = {  "crn": crn_model }#, "rnn": rnn_model}#, "cnn": cnn_model,}
out = evaluate_e2e_one_sample(models)

print("Duration of audio: " + str(out["num_frames"] / 16000.0))
for name in models.keys():
    model_dict = out[name]
    print(f"{name.upper()} -- e2e:{ns_to_sec(model_dict["perf"]["e2e_time"])}, average forward:{ns_to_sec(model_dict["perf"]["avg_forward"])}")

  return torch.tensor(self.l_waveforms[idx]).unsqueeze_(0), torch.tensor(self.r_waveforms[idx]).unsqueeze_(0)
  out = torch.tensor(self.audio[:,0,:self.end - self.frame_shift]).unsqueeze_(1)


ValueError: Array audio input must be a 1D or 2D array