In [10]:
#!conda env update --file environment.yml --prune

In [11]:
import torch, torchaudio
import torch.nn as nn
import torchaudio.functional as audioF
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator, EventEnum
from ignite.metrics import Accuracy, Loss, Metric

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# pd.options.display.max_seq_items = 2000
pd.set_option("display.max_colwidth",None)

import os, random, glob, logging, ntpath, math, time
logging.basicConfig()
logger=logging.getLogger("dbg")
logger.setLevel(logging.DEBUG)
perf_logger=logging.getLogger("perf")
perf_logger.setLevel(logging.DEBUG)
# logging.disable(logging.DEBUG)

from IPython.display import Audio
from tqdm import tqdm

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.debug(device)
logger.debug(torch.__version__)

DEBUG:dbg:cuda:0
DEBUG:dbg:2.6.0+cu124


# Utility

In [12]:
logger.debug(torchaudio.list_audio_backends())

DEBUG:dbg:['ffmpeg']


In [13]:
import pystoi
import pesq

def combine_audio(speech: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor | int) -> torch.Tensor:
    if not (torch.is_floating_point(speech) or torch.is_complex(speech)):
        # speech = torch.tensor(speech, dtype=torch.float64, device=speech.device)
        speech = speech.to(torch.float64,non_blocking=True)
    if not (torch.is_floating_point(noise) or torch.is_complex(noise)):
        # noise = torch.tensor(noise, dtype=torch.float64, device=noise.device)
        noise = noise.to(torch.float64,non_blocking=True)
    if not(type(snr) is torch.Tensor):
        snr = torch.tensor([snr])
    logger.debug(f"speech:{speech.ndim}, noise:{noise.ndim}, snr:{snr.ndim}")
    out = audioF.add_noise(speech, noise, snr)
    return out

def calc_pesq(speech: np.ndarray, processed: np.ndarray) -> float:
    return pesq.pesq(ref=speech, deg=processed, fs=16000)

def calc_stoi(speech: np.ndarray, processed: np.ndarray) -> float:
    return pystoi.stoi(x=speech, y=processed, fs_sig=16000)

def ns_to_sec(ns: int):
    return ns/1000000000.0

def plot_waveform(waveform, sample_rate=16000):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")

In [53]:
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 4
SHUFFLE = False
FRAME_SHIFT = 40

# ldrnd = random.Random(42)   #   Used for noise loading 

def get_sequential_wav_paths(dir):
    count = len(glob.glob("*.wav", root_dir=dir))
    lst = []
    for i in range(1,count+1):
        lst.append(dir + "/" + str(i) + ".wav")
    
    return lst

class AudioDataset(Dataset):
    def __init__(self, data: list, root_dir: str | None = None):
        self.data = data
        if root_dir==None:
            root_dir = os.getcwd()+"/data"
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave

    def get(self,idx):
        wave, _ = torchaudio.load(self.root_dir + "/" + self.data[idx], format="wav")
        return wave
    
class SortedSeqShuffledBatch(Dataset):
    def __init__(self, mixed: list, clean: list, batch_size: int = 1):
        self.batch_size = batch_size
        self.mixed = mixed
        self.clean = clean
        # if root_dir==None:
        #     root_dir = os.getcwd()+"/data"
        # self.root_dir = root_dir

    def __len__(self):
        return math.floor(len(self.mixed)/self.batch_size)
    
    def __getitem__(self,idx):
        mixeds = []
        cleans = []
        max_len = 0
        for i in range(self.batch_size):
            if idx + i > len(self.mixed): continue
            mixed_wave, _ = torchaudio.load(self.mixed[idx*self.batch_size+i])
            clean_wave, _ = torchaudio.load(self.clean[idx*self.batch_size+i])
            mixed_wave = mixed_wave[0]
            clean_wave = clean_wave[0]
            assert(mixed_wave.shape[0]==clean_wave.shape[0])
            if i==0:
                max_len = mixed_wave.shape[0]
            else:
                mixed_wave = torch.nn.functional.pad(mixed_wave,(0,max_len-(mixed_wave.shape[0])),value=0.0)
                clean_wave = torch.nn.functional.pad(clean_wave,(0,max_len-(clean_wave.shape[0])),value=0.0)
            # logger.debug(mixed_wave)
            mixeds.append(mixed_wave)
            cleans.append(clean_wave)
        mixeds = np.asarray(mixeds)
        cleans = np.asarray(cleans)
        return torch.tensor(mixeds).unsqueeze(1), torch.tensor(cleans).unsqueeze(1)

    def split(self, val_pct, seed=None):
        rnd = random.Random()
        if seed is not None:
            rnd.seed(seed)
        this_len = len(self)
        val_batches = math.floor(this_len*val_pct)
        val_indices = sorted(rnd.sample(range(this_len), val_batches))
        train_mixed = []
        train_clean = []
        val_mixed = []
        val_clean = []
        for i in range(this_len):
            if i in val_indices:
                for x in range(self.batch_size):
                    val_mixed.append(self.mixed[i*self.batch_size+x])
                    val_clean.append(self.clean[i*self.batch_size+x])
            else:
                for x in range(self.batch_size):
                    train_mixed.append(self.mixed[i*self.batch_size+x])
                    train_clean.append(self.clean[i*self.batch_size+x])
        
        return SortedSeqShuffledBatch(train_mixed,train_clean,self.batch_size), SortedSeqShuffledBatch(val_mixed,val_clean,self.batch_size)


class FrameLoaderEvents(EventEnum):
    END_OF_SAMPLE = "end_of_sample"

class FrameLoader():
    '''Takes a dataloader, frame size and frame shift. It can then be iterated over to produce frames.\n
    Provides padding when sample length would be exceeded.'''

    def __init__(self, dl: DataLoader, frame_size: int, frame_shift: int, engine: Engine | None = None):
        self.iter = iter(dl)
        self.batch_count = len(dl)
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.batch_mixed: torch.Tensor
        self.batch_clean: torch.Tensor
        self.frame_position = 0
        self.at_end = True
        self.engine = engine
    def __iter__(self):
        return self
    def __next__(self):
        if self.at_end:
            batches: tuple[torch.Tensor,torch.Tensor] = next(self.iter)
            self.batch_mixed = batches[0].to(device=device).squeeze_(0)
            self.batch_clean = batches[0].to(device=device).squeeze_(0)
            self.frame_position = 0
            self.at_end = False
            logger.debug(f"mixed batch shape:{self.batch_mixed.shape} | clean batch shape:{self.batch_clean.shape}")
        
        frame_end = self.frame_position + self.frame_size
        frames = []
        for batch_i, batch in enumerate([self.batch_mixed, self.batch_clean]):
            shp = batch.shape
            frame: torch.Tensor
            if frame_end >= shp[2]:
                if self.engine is not None and batch_i==0: 
                    self.engine.fire_event(FrameLoaderEvents.END_OF_SAMPLE)
                    print("emitted FrameLoaderEvents.END_OF_SAMPLE")
                self.at_end = True
                if frame_end != shp[2]:
                    diff = frame_end - shp[2]
                    # Pad batch until aligned with frame_end
                    frame = torch.zeros((BATCH_SIZE,1,self.frame_size),dtype=torch.float32).to(device="cpu")
                    frame[:, 0, 0:self.frame_size - diff] = batch[:, 0, self.frame_position:shp[2]]
                else:
                    frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)
            else:
                frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)
            frames.append(frame)

        
        self.frame_position += self.frame_shift
        # logger.debug(f"FrameLoader: frame_position[{self.frame_position}] - frame_end[{frame_end}]")
        # perf_logger.debug(f"Time to load frame at ({self.frame_position}): {time.perf_counter() - start}s")
        if frames[0].shape[2] != 96:
            logger.debug(frames[0].shape[2])
            logger.debug("FrameLoader issue")
        return frames[0], frames[1]


dataset = SortedSeqShuffledBatch(get_sequential_wav_paths("data/mixed/train"), get_sequential_wav_paths("data/speech_ordered/train"), batch_size=BATCH_SIZE)
train_dataset, val_dataset = dataset.split(0.2,42)
del dataset
base_train_dataloader = DataLoader(train_dataset, shuffle=SHUFFLE)

In [15]:
print(len(train_dataset))

print(len(val_dataset))

20
5


# Models

In [16]:
criterion = nn.MSELoss()
val_metrics : dict[str, Metric]= {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}
def register_custom_events(eng: Engine):
    eng.register_events(*FrameLoaderEvents)

def log_trainer_loss(eng: Engine, **kwargs):
    prefix = kwargs.get("prefix","")
    print(f"{prefix}Epoch[{eng.state.epoch}], Iter[{eng.state.iteration}] Loss: {eng.state.output}")

## SEGAN

In [17]:
from models.segan import Generator, Discriminator

def segan_train(models: tuple):
    generator, discriminator = models

## WaveCRN

In [55]:
from models.wavecrn import ConvBSRU
logger.debug("wavecrn loaded")
CRN_FRAME_SIZE = 96

crn_model: ConvBSRU
crn_optimizer: torch.optim.Adam
if 'crn_model' in locals(): del crn_model
if 'crn_optimizer' in locals(): del crn_optimizer

torch.cuda.empty_cache()

crn_model = ConvBSRU(frame_size=CRN_FRAME_SIZE, conv_channels=256, stride=48, num_layers=6, dropout=0.0).to(device=device)
crn_optimizer = torch.optim.Adam(crn_model.parameters(),lr=0.0001)


def wavecrn_train_step(engine, batch):
    # s0 = time.perf_counter()
    crn_model.train()
    crn_optimizer.zero_grad()
    x, y = batch[0].to(device), batch[1].to(device)
    # s1 = time.perf_counter()
    y_pred = crn_model(x)
    # logger.debug(f"x_shape:{x.shape}, y_shape:{y.shape}, y_pred_shape={y_pred.shape}")
    # e1 = time.perf_counter()
    loss = criterion(y_pred, y)
    # s2 = time.perf_counter()
    loss.backward()
    crn_optimizer.step()
    # e2 = time.perf_counter()
    # perf_logger.debug(f"Time for [y_pred = crn_model(x)]:{e1-s1}s | Time for [backprop]:{e2-s2}s")
    # perf_logger.debug(f"Total in [wavecrn_train_step()]: {time.perf_counter()-s0}")
    
    return loss.item()

trainer = Engine(wavecrn_train_step)
register_custom_events(trainer)
trainer.add_event_handler(FrameLoaderEvents.END_OF_SAMPLE,log_trainer_loss,prefix="Sample completed | ")
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=100),log_trainer_loss)

def crn_validation_step(engine, batch):
    crn_model.eval()
    with torch.no_grad():
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = crn_model(x)
        return y_pred, y

crn_train_evaluator = Engine(crn_validation_step)
crn_val_evaluator = Engine(crn_validation_step)

for name, metric in val_metrics.items():
    metric.attach(crn_train_evaluator, name)

crn_train_dataloader = FrameLoader(base_train_dataloader, CRN_FRAME_SIZE, FRAME_SHIFT, trainer)

trainer.run(crn_train_dataloader, max_epochs=5)

DEBUG:dbg:wavecrn loaded
DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 26239]) | clean batch shape:torch.Size([4, 1, 26239])
  frame = torch.tensor(batch[:,:,self.frame_position:frame_end],device=device)


Epoch[1], Iter[100] Loss: 0.0011000315425917506
Epoch[1], Iter[200] Loss: 0.000707107363268733
Epoch[1], Iter[300] Loss: 0.002963124541565776
Epoch[1], Iter[400] Loss: 0.0003949045203626156
Epoch[1], Iter[500] Loss: 0.0009703459218144417
Epoch[1], Iter[600] Loss: 0.0007356252754107118


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 27840]) | clean batch shape:torch.Size([4, 1, 27840])


Sample completed | Epoch[1], Iter[654] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[700] Loss: 3.5827983083436266e-05
Epoch[1], Iter[800] Loss: 3.115067011094652e-05
Epoch[1], Iter[900] Loss: 9.470204531680793e-05
Epoch[1], Iter[1000] Loss: 0.0037492597475647926
Epoch[1], Iter[1100] Loss: 0.00045707146637141705
Epoch[1], Iter[1200] Loss: 0.0005885479622520506
Epoch[1], Iter[1300] Loss: 2.6310173780075274e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 29520]) | clean batch shape:torch.Size([4, 1, 29520])


Sample completed | Epoch[1], Iter[1349] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[1400] Loss: 7.679520308556675e-07
Epoch[1], Iter[1500] Loss: 3.915267006959766e-05
Epoch[1], Iter[1600] Loss: 0.0005237514851614833
Epoch[1], Iter[1700] Loss: 0.0006559373578056693
Epoch[1], Iter[1800] Loss: 0.00015464953321497887
Epoch[1], Iter[1900] Loss: 0.0001188894675578922
Epoch[1], Iter[2000] Loss: 3.078078952967189e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 29760]) | clean batch shape:torch.Size([4, 1, 29760])


Sample completed | Epoch[1], Iter[2086] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[2100] Loss: 5.091029015602544e-05
Epoch[1], Iter[2200] Loss: 8.707607048563659e-05
Epoch[1], Iter[2300] Loss: 0.00010421356273582205
Epoch[1], Iter[2400] Loss: 0.0007313747773878276
Epoch[1], Iter[2500] Loss: 0.0003358969115652144
Epoch[1], Iter[2600] Loss: 0.00013803022739011794
Epoch[1], Iter[2700] Loss: 7.86840173532255e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 31040]) | clean batch shape:torch.Size([4, 1, 31040])


Epoch[1], Iter[2800] Loss: 5.871110988664441e-05
Sample completed | Epoch[1], Iter[2829] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[2900] Loss: 3.493174153845757e-05
Epoch[1], Iter[3000] Loss: 0.00015691568842157722
Epoch[1], Iter[3100] Loss: 0.00010275563545292243
Epoch[1], Iter[3200] Loss: 4.069020360475406e-05
Epoch[1], Iter[3300] Loss: 2.5575316612957977e-05
Epoch[1], Iter[3400] Loss: 4.474545858101919e-05
Epoch[1], Iter[3500] Loss: 0.00010858091991394758


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 31120]) | clean batch shape:torch.Size([4, 1, 31120])


Epoch[1], Iter[3600] Loss: 2.6534700737101957e-05
Sample completed | Epoch[1], Iter[3604] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[3700] Loss: 2.619422593852505e-05
Epoch[1], Iter[3800] Loss: 7.207015005405992e-05
Epoch[1], Iter[3900] Loss: 0.00010135710908798501
Epoch[1], Iter[4000] Loss: 5.8567449741531163e-05
Epoch[1], Iter[4100] Loss: 6.094941636547446e-05
Epoch[1], Iter[4200] Loss: 2.8670707251876593e-05
Epoch[1], Iter[4300] Loss: 2.376262091274839e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 31680]) | clean batch shape:torch.Size([4, 1, 31680])


Sample completed | Epoch[1], Iter[4381] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[4400] Loss: 2.1262623704387806e-05
Epoch[1], Iter[4500] Loss: 2.4957957066362724e-05
Epoch[1], Iter[4600] Loss: 2.59780772466911e-05
Epoch[1], Iter[4700] Loss: 0.00013097802002448589
Epoch[1], Iter[4800] Loss: 2.5413775802007876e-05
Epoch[1], Iter[4900] Loss: 4.771915700985119e-05
Epoch[1], Iter[5000] Loss: 2.1012485376559198e-05
Epoch[1], Iter[5100] Loss: 1.680511559243314e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 32320]) | clean batch shape:torch.Size([4, 1, 32320])


Sample completed | Epoch[1], Iter[5172] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[5200] Loss: 4.761473974213004e-06
Epoch[1], Iter[5300] Loss: 0.0013585289707407355
Epoch[1], Iter[5400] Loss: 6.900738662807271e-05
Epoch[1], Iter[5500] Loss: 2.54962760664057e-05
Epoch[1], Iter[5600] Loss: 3.4592205338412896e-05
Epoch[1], Iter[5700] Loss: 2.2597781935473904e-05
Epoch[1], Iter[5800] Loss: 6.0001730162184685e-06
Epoch[1], Iter[5900] Loss: 7.294217539310921e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 32640]) | clean batch shape:torch.Size([4, 1, 32640])


Sample completed | Epoch[1], Iter[5979] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[6000] Loss: 8.122426515910774e-05
Epoch[1], Iter[6100] Loss: 0.0001151474061771296
Epoch[1], Iter[6200] Loss: 5.597959898295812e-05
Epoch[1], Iter[6300] Loss: 0.00028324415325187147
Epoch[1], Iter[6400] Loss: 0.0025664721615612507
Epoch[1], Iter[6500] Loss: 0.00011022271792171523
Epoch[1], Iter[6600] Loss: 6.320605461951345e-05
Epoch[1], Iter[6700] Loss: 5.6744611356407404e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 32960]) | clean batch shape:torch.Size([4, 1, 32960])


Sample completed | Epoch[1], Iter[6794] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[6800] Loss: 1.2167943168606143e-05
Epoch[1], Iter[6900] Loss: 6.944942924747011e-06
Epoch[1], Iter[7000] Loss: 9.18142031878233e-05
Epoch[1], Iter[7100] Loss: 1.2609045370481908e-05
Epoch[1], Iter[7200] Loss: 2.1186591766308993e-05
Epoch[1], Iter[7300] Loss: 6.363922580021608e-07
Epoch[1], Iter[7400] Loss: 1.1447742508607917e-05
Epoch[1], Iter[7500] Loss: 6.038317224010825e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 33120]) | clean batch shape:torch.Size([4, 1, 33120])


Epoch[1], Iter[7600] Loss: 3.395905423531076e-07
Sample completed | Epoch[1], Iter[7617] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[7700] Loss: 2.694604290809366e-06
Epoch[1], Iter[7800] Loss: 2.7728515306080226e-06
Epoch[1], Iter[7900] Loss: 4.334217373980209e-05
Epoch[1], Iter[8000] Loss: 0.00014600332360714674
Epoch[1], Iter[8100] Loss: 1.648542092880234e-05
Epoch[1], Iter[8200] Loss: 5.3319581638788804e-06
Epoch[1], Iter[8300] Loss: 7.30755982658593e-06
Epoch[1], Iter[8400] Loss: 3.455497562754317e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 33440]) | clean batch shape:torch.Size([4, 1, 33440])


Sample completed | Epoch[1], Iter[8444] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[8500] Loss: 6.70608415020979e-06
Epoch[1], Iter[8600] Loss: 1.9980907381977886e-05
Epoch[1], Iter[8700] Loss: 5.303139187162742e-05
Epoch[1], Iter[8800] Loss: 4.51028099632822e-05
Epoch[1], Iter[8900] Loss: 1.830649853218347e-05
Epoch[1], Iter[9000] Loss: 1.3304494132171385e-05
Epoch[1], Iter[9100] Loss: 1.355113181489287e-05
Epoch[1], Iter[9200] Loss: 9.220262654707767e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 33920]) | clean batch shape:torch.Size([4, 1, 33920])


Sample completed | Epoch[1], Iter[9279] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[9300] Loss: 1.6015491155485506e-06
Epoch[1], Iter[9400] Loss: 6.24170497758314e-05
Epoch[1], Iter[9500] Loss: 1.4895833828632021e-06
Epoch[1], Iter[9600] Loss: 4.918391368846642e-06
Epoch[1], Iter[9700] Loss: 0.00022882662597112358
Epoch[1], Iter[9800] Loss: 4.5504120862460695e-06
Epoch[1], Iter[9900] Loss: 9.519942068436649e-06
Epoch[1], Iter[10000] Loss: 4.729810825665481e-06
Epoch[1], Iter[10100] Loss: 8.750041047278501e-07
Sample completed | Epoch[1], Iter[10126] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 34240]) | clean batch shape:torch.Size([4, 1, 34240])


Epoch[1], Iter[10200] Loss: 4.476896720007062e-06
Epoch[1], Iter[10300] Loss: 9.488387604505988e-07
Epoch[1], Iter[10400] Loss: 3.7957065615046304e-06
Epoch[1], Iter[10500] Loss: 6.3308962126029655e-06
Epoch[1], Iter[10600] Loss: 4.782023097504862e-06
Epoch[1], Iter[10700] Loss: 2.879063231375767e-06
Epoch[1], Iter[10800] Loss: 6.076113550079754e-06
Epoch[1], Iter[10900] Loss: 2.865821443265304e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 34400]) | clean batch shape:torch.Size([4, 1, 34400])


Sample completed | Epoch[1], Iter[10981] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[11000] Loss: 1.941763548529707e-05
Epoch[1], Iter[11100] Loss: 1.64751909323968e-05
Epoch[1], Iter[11200] Loss: 3.8558075175387785e-05
Epoch[1], Iter[11300] Loss: 2.7902813599212095e-05
Epoch[1], Iter[11400] Loss: 5.841458187205717e-05
Epoch[1], Iter[11500] Loss: 3.718534571817145e-05
Epoch[1], Iter[11600] Loss: 1.53385080920998e-05
Epoch[1], Iter[11700] Loss: 9.318881893705111e-06
Epoch[1], Iter[11800] Loss: 1.1038409866159782e-05


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 34720]) | clean batch shape:torch.Size([4, 1, 34720])


Sample completed | Epoch[1], Iter[11840] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[11900] Loss: 1.3873193438485032e-06
Epoch[1], Iter[12000] Loss: 4.482444182940526e-06
Epoch[1], Iter[12100] Loss: 0.000171233230503276
Epoch[1], Iter[12200] Loss: 4.380198788567213e-06
Epoch[1], Iter[12300] Loss: 3.342254331073491e-06
Epoch[1], Iter[12400] Loss: 4.120838639209978e-05
Epoch[1], Iter[12500] Loss: 6.758532890671631e-06
Epoch[1], Iter[12600] Loss: 8.473303751088679e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 34960]) | clean batch shape:torch.Size([4, 1, 34960])


Epoch[1], Iter[12700] Loss: 3.1295853659685235e-06
Sample completed | Epoch[1], Iter[12707] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[12800] Loss: 8.064283974817954e-06
Epoch[1], Iter[12900] Loss: 0.00015007528418209404
Epoch[1], Iter[13000] Loss: 2.8893998660350917e-06
Epoch[1], Iter[13100] Loss: 5.709644028684124e-06
Epoch[1], Iter[13200] Loss: 7.618530617037322e-06
Epoch[1], Iter[13300] Loss: 3.0240785235946532e-06
Epoch[1], Iter[13400] Loss: 3.970966281485744e-05
Epoch[1], Iter[13500] Loss: 3.0568387501261896e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 35360]) | clean batch shape:torch.Size([4, 1, 35360])


Sample completed | Epoch[1], Iter[13580] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[13600] Loss: 8.165751523847575e-07
Epoch[1], Iter[13700] Loss: 1.1528585673659109e-05
Epoch[1], Iter[13800] Loss: 1.0367303730163258e-05
Epoch[1], Iter[13900] Loss: 6.922839929757174e-06
Epoch[1], Iter[14000] Loss: 2.797126853693044e-06
Epoch[1], Iter[14100] Loss: 6.7311957536730915e-06
Epoch[1], Iter[14200] Loss: 8.518467211615643e-07
Epoch[1], Iter[14300] Loss: 4.993000402464531e-06
Epoch[1], Iter[14400] Loss: 5.842065888828074e-07


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 35600]) | clean batch shape:torch.Size([4, 1, 35600])


Sample completed | Epoch[1], Iter[14463] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[14500] Loss: 1.0470121196703985e-05
Epoch[1], Iter[14600] Loss: 6.502655196527485e-06
Epoch[1], Iter[14700] Loss: 3.495753480819985e-05
Epoch[1], Iter[14800] Loss: 2.2378437279257923e-05
Epoch[1], Iter[14900] Loss: 6.281094101723284e-05
Epoch[1], Iter[15000] Loss: 1.014391273201909e-05
Epoch[1], Iter[15100] Loss: 6.513138941954821e-06
Epoch[1], Iter[15200] Loss: 6.830052370787598e-06
Epoch[1], Iter[15300] Loss: 5.125997176946839e-06


DEBUG:dbg:mixed batch shape:torch.Size([4, 1, 36080]) | clean batch shape:torch.Size([4, 1, 36080])


Sample completed | Epoch[1], Iter[15352] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE
Epoch[1], Iter[15400] Loss: 6.428815026993107e-07
Epoch[1], Iter[15500] Loss: 2.334959845029516e-06
Epoch[1], Iter[15600] Loss: 2.295069680258166e-05
Epoch[1], Iter[15700] Loss: 4.288455784262624e-06
Epoch[1], Iter[15800] Loss: 1.0959573046420701e-05
Epoch[1], Iter[15900] Loss: 1.937245087901829e-06
Epoch[1], Iter[16000] Loss: 1.0017192835221067e-05
Epoch[1], Iter[16100] Loss: 2.285943082824815e-06
Epoch[1], Iter[16200] Loss: 2.1153446994048863e-07
Sample completed | Epoch[1], Iter[16253] Loss: None
emitted FrameLoaderEvents.END_OF_SAMPLE




State:
	iteration: 16254
	epoch: 5
	epoch_length: 16254
	max_epochs: 5
	output: <class 'NoneType'>
	batch: <class 'NoneType'>
	metrics: <class 'dict'>
	dataloader: <class '__main__.FrameLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

## RHR-Net

In [None]:
import yaml
from models.rhrnetdir.Arg_Parser import Recursive_Parse
from models.rhrnet import RHRNet
hp = Recursive_Parse(yaml.load(
    open('models/rhrnetdir/rhrnet_hyperparameters.yaml', encoding='utf-8'),
    Loader=yaml.Loader
    ))  
net = RHRNet(hp)
net(torch.randn(3, 1, 1024))

def rhrnet_train(model):


## Wave-U-Net

In [None]:
from models.waveunet import Waveunet
waveunet_model = Waveunet()

def waveunet_train(model):
    