In [20]:
!pip install torchinfo
!pip install mne
!pip install scipy==1.11.1
!pip install omegaconf
!pip install xlstm
!pip install dacite



In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import from_numpy as np2TT
from torchinfo import summary

from os.path import expanduser
from scipy.io import savemat
import numpy as np
import math
import json
import time
import mne
import sys
import os
import wandb
import matplotlib.pyplot as plt
from datetime import datetime
import yaml
from pathlib import Path
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from scipy.io import loadmat
from scipy import signal
from matplotlib.colors import rgb2hex
from sklearn.model_selection import train_test_split

from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

Device: cuda


In [23]:
use_google_drive = True
if use_google_drive:
    from google.colab import drive
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Proposed model


In [24]:
class LSTMConvAutoencoder3(nn.Module):
    def __init__(self, input_dim=18, num_layers = 1):
        super(LSTMConvAutoencoder3, self).__init__()

        self.lstm = nn.LSTM(256, 256, num_layers, batch_first=True)

        # Encoder: 3 1D convolutional layers
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, 64, kernel_size=5, stride=2, padding=2),  # [batch_size, 64, 256]
            #nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),             # [batch_size, 128, 128]
            #nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2),            # [batch_size, 256, 64]
            #nn.ReLU()
        )

        # Decoder: 3 1D transposed convolutional layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # [batch_size, 128, 128]
            #nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),   # [batch_size, 64, 256]
            #nn.ReLU(),
            nn.ConvTranspose1d(64, input_dim, kernel_size=5, stride=2, padding=2, output_padding=1),  # [batch_size, input_channels, 512]
            #nn.Sigmoid()  # Sigmoid for output normalization between 0 and 1
        )

    def forward(self, x):
        x = x.squeeze(1)

        # Forward pass through encoder
        x = self.encoder(x)                #encoder takes [batch_size, input_dim, sequence_length]
        x, _ = self.lstm(x.permute(0,2,1))
        # Forward pass through decoder
        x = self.decoder(x.permute(0,2,1))

        return x.unsqueeze(1)

# CLEEGN

In [25]:
class Permute2d(nn.Module):
    def __init__(self, shape):
        super(Permute2d, self).__init__()
        self.shape = shape

    def forward(self, x):
        return torch.permute(x, self.shape)

In [26]:
class CLEEGN(nn.Module):
    def __init__(self, n_chan, fs, N_F=20, tem_kernelLen=0.1):
        super(CLEEGN,self).__init__()
        self.n_chan = n_chan
        self.N_F = N_F
        self.fs = fs
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, n_chan, (n_chan, 1), padding="valid", bias=True),
            Permute2d((0, 2, 1, 3)),
            nn.BatchNorm2d(1, eps=1e-3, momentum=0.99)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, N_F, (1, int(fs * tem_kernelLen)), padding="same", bias=True),
            nn.BatchNorm2d(N_F, eps=1e-3, momentum=0.99)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(N_F, N_F, (1, int(fs * tem_kernelLen)), padding="same", bias=True),
            nn.BatchNorm2d(N_F, eps=1e-3, momentum=0.99)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(N_F, n_chan, (n_chan, 1), padding="same", bias=True),
            nn.BatchNorm2d(n_chan, eps=1e-3, momentum=0.99)
        )
        self.conv5 = nn.Conv2d(n_chan, 1, (n_chan,1), padding="same", bias=True)

    def forward(self, x):
        # encoder
        x = self.conv1(x)
        x = self.conv2(x)
        # decoder
        x = self.conv3(x)
        x = self.conv4(x)

        x = self.conv5(x)
        return x

# 1D-ResCNN


In [27]:
class Res_BasicBlock(nn.Module):
    def __init__(self, kernelsize, stride=1):
        super(Res_BasicBlock, self).__init__()
        self.bblock = nn.Sequential(
            nn.Conv1d(32, 32, kernel_size=kernelsize, stride=stride, padding='same'),
            nn.BatchNorm1d(32),
            #nn.ReLU(),
            nn.Conv1d(32, 16, kernel_size=kernelsize, stride=1, padding='same'),
            nn.BatchNorm1d(16),
            #nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=kernelsize, stride=1, padding='same'),
            nn.BatchNorm1d(32),
            #nn.ReLU()
        )

    def forward(self, x):
        out = self.bblock(x)
        identity = x
        output = out + identity  # element-wise addition for skip connection
        return output

class BasicBlockall(nn.Module):
    def __init__(self):
        super(BasicBlockall, self).__init__()
        self.bblock3 = nn.Sequential(Res_BasicBlock(3), Res_BasicBlock(3))
        self.bblock5 = nn.Sequential(Res_BasicBlock(5), Res_BasicBlock(5))
        self.bblock7 = nn.Sequential(Res_BasicBlock(7), Res_BasicBlock(7))

    def forward(self, x):
        out3 = self.bblock3(x)
        out5 = self.bblock5(x)
        out7 = self.bblock7(x)
        out = torch.cat([out3, out5, out7], dim=1)  # concatenate along channel dimension
        return out

class OneD_ResCNN(nn.Module):
    def __init__(self, seq_length, batch_size, n_chan):
        super(OneD_ResCNN, self).__init__()
        self.batch_size = batch_size
        self.n_chan = n_chan

        self.initial_conv = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=5, stride=1, padding='same'),
            nn.BatchNorm1d(32),
            #nn.ReLU()
        )
        self.basic_block = BasicBlockall()
        self.final_conv = nn.Sequential(
            nn.Conv1d(96, 1, kernel_size=1, stride=1, padding='same'),  # Adjust channels after concatenation
            nn.BatchNorm1d(1),
            #nn.ReLU()
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(seq_length, seq_length)

    def forward(self, x):
        #x = x.squeeze()
        #x = x.unsqueeze(1)
        x = x.squeeze(1)
        x = x.view(self.batch_size*self.n_chan, 1, -1)
        x = self.initial_conv(x)
        x = self.basic_block(x)
        x = self.final_conv(x)
        x = self.flatten(x)
        x = self.fc(x)
        x = x.view(self.batch_size, 1, self.n_chan, -1)
        return x


# IC-U_Net

In [28]:
class CBR(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        padding = int((kernel_size - 1) / 2)

        self.double_conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class IC_U_NET(nn.Module):
    def __init__(self, input_channels=18):
        super(IC_U_NET, self).__init__()

        self.enc1 = CBR(input_channels, 64, 7)
        self.enc2 = CBR(64, 128, 7)
        self.enc3 = CBR(128, 256, 5)
        self.enc4 = CBR(256, 512, 3)
        self.dec1 = CBR(512, 256, 3)
        self.dec2 = CBR(256, 128, 3)
        self.dec3 = CBR(128, 64, 3)
        self.dec4 = nn.Sequential(
            nn.Conv1d(128, 64, kernel_size=1, padding=0),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, input_channels, kernel_size=1, padding=0),
            nn.BatchNorm1d(input_channels),
            nn.ReLU(inplace=True),
        )
        self.maxPool = nn.MaxPool1d(2)
        self.tConv1 = nn.ConvTranspose1d(512, 512, kernel_size=2, stride=2)
        self.tConv2 = nn.ConvTranspose1d(512, 256, kernel_size=2, stride=2)
        self.tConv3 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)


        # Encoder: 3 1D convolutional layers
        self.encoder = nn.Sequential(
            self.enc1,
            nn.MaxPool1d(2),
            self.enc2,
            nn.MaxPool1d(2),
            self.enc3,
            nn.MaxPool1d(2),
            self.enc4,
        )



        # Decoder: 3 1D transposed convolutional layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(512, 512, kernel_size=2, stride=2),  # [batch_size, 128, 128]
            self.dec1,
            nn.ConvTranspose1d(256, 256, kernel_size=2, stride=2),   # [batch_size, 64, 256]
            self.dec2,
            nn.ConvTranspose1d(128, 128, kernel_size=2, stride=2),  # [batch_size, input_channels, 512]
            self.dec3,
            self.dec4,
        )

    def forward(self, x):
        x = x.squeeze(1)

        #encoder
        skip1 = self.enc1(x)
        x = self.maxPool(skip1)
        skip2 = self.enc2(x)
        x = self.maxPool(skip2)
        skip3 = self.enc3(x)
        x = self.maxPool(skip3)
        x = self.enc4(x)

        #decoder
        x = self.tConv1(x)
        x1 = self.dec1(x)
        x = torch.cat([x1, skip3], dim=1)
        x = self.tConv2(x)
        x2 = self.dec2(x)
        x = torch.cat([x2, skip2], dim=1)
        x = self.tConv3(x)
        x3 = self.dec3(x)
        x = torch.cat([x3, skip1], dim=1)
        x = self.dec4(x)
        return x.unsqueeze(1)

# Dataset

In [29]:
def segment_eeg(eeg, window_size=100, stride=50):
    """ Session EEG Signal by Slinding Window """
    n_chan, n_timep = eeg.shape
    tstamps, segments = [], []
    for i in range(0, n_timep, stride):
        seg = eeg[:,i: i + window_size]
        if seg.shape != (n_chan, window_size):
            break
        segments.append(seg)
        tstamps.append(i)

    return segments, tstamps

def create_dataset(x_fpath, y_fpath, fmt_terms, tmin, tmax, ch_names=None, win_size=4, stride=2):
    """ read mne set to numpy array """
    x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
    sfreq = x_raw.info["sfreq"]
    win_size = math.ceil(win_size * sfreq)
    stride = math.ceil(stride * sfreq)
    nb_chan = len(x_raw.ch_names if ch_names is None else ch_names)
    tmin = math.ceil(tmin * sfreq)
    tmax = math.ceil(tmax * sfreq)

    X = np.zeros((0, nb_chan, win_size), dtype=np.float32)
    y = np.zeros((0, nb_chan, win_size), dtype=np.float32)
    for fmt_term in fmt_terms:
        x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
        y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
        if len(x_raw.ch_names) != len(y_raw.ch_names):
            raise ValueError(f"EEG channel should be matched, found {len(x_raw.ch_names)} and {len(y_raw.ch_names)}")

        x_content = x_raw[:, tmin: tmax][0]
        y_content = y_raw[:, tmin: tmax][0]

        #np.shape(x_content)
        #plt.hist(x_content*1e+6)
        #plt.show()

        if ch_names is not None:  # channel selection
            picks = [x_raw.ch_names.index(ch) for ch in ch_names]
            x_content = x_content[picks]
            y_content = y_content[picks]

        x_seg = np.array(segment_eeg(x_content, win_size, stride)[0])
        y_seg = np.array(segment_eeg(y_content, win_size, stride)[0])
        X = np.append(X, np.array(x_seg), axis=0)
        y = np.append(y, np.array(y_seg), axis=0)
    return X*1e+6, y*1e+6    #TODO

def calc_SNR(clean_data, noisy_data, inDezibel = True):
    # clean data: reference data
    # noisy data: data to measure SNR on, e.g. output of the model
    n_chan = clean_data.shape[0]

    if inDezibel:
        return 1/n_chan * np.sum(10 * np.log10(np.linalg.norm(clean_data, axis = 1)/np.linalg.norm(clean_data-noisy_data, axis = 1)))
    else:
        return 1/n_chan * np.sum(np.linalg.norm(clean_data, axis = 1)**2/np.linalg.norm(clean_data-noisy_data, axis = 1)**2)


def create_EEG_DenoiseNet_dataset(config, artifact_type, debug = False):
    #TODO what about the upsampling?

    np.random.seed(0)  # for reproducibility
    snr_eog = np.linspace(-7,2, 10) # in dB
    snr_emg = np.linspace(-7,4, 12) # in dB

    Y = []
    X = []

    y = np.load(config['EEG_path'])        # clean segments


    #l_sum = 0

    ############### EOG ##################
    if artifact_type == 'EOG':
        n_EOG = np.load(config['EOG_path'])    #EOG noise segments
        nr_eog_segs = n_EOG.shape[0]
        selected_eeg_indices = np.random.choice(y.shape[0], nr_eog_segs, replace=False)
        selected_eeg_segments = y[selected_eeg_indices]

        #l = np.sum(x[0][:]**2)/np.sum(n_EOG[0][:]**2)/(10**(snr/5))
        #EOG: 3400 segments, EEG: 4514 segments
        #for i in range(nr_eog_segs):
            #l_sum += np.linalg.norm(selected_eeg_segments[i][:])/np.linalg.norm(n_EOG[i][:])/(10**(snr/10))
        #l = l_sum/nr_eog_segs
        for snr in snr_eog:
            l = np.linalg.norm(selected_eeg_segments.flatten())/np.linalg.norm(n_EOG.flatten())/(10**(snr/10))
            x = selected_eeg_segments + l*n_EOG
            if debug:
                snr_check_eog = calc_SNR(np.expand_dims(selected_eeg_segments.flatten(), 0), np.expand_dims(x.flatten(), 0))
            Y.append(selected_eeg_segments)
            X.append(x)

    elif artifact_type == 'EMG':
    ############### EMG ##################
        n_EMG = np.load(config['EMG_path'])    #EMG noise segments
        nr_emg_segs = n_EMG.shape[0]
        np.random.shuffle(n_EMG)
        selected_eeg_indices = np.random.choice(y.shape[0], nr_emg_segs - y.shape[0], replace=False)
        selected_eeg_segments = y[selected_eeg_indices]
        y_expanded = np.vstack((y,selected_eeg_segments))
        for snr in snr_emg:
            l = np.linalg.norm(y_expanded.flatten())/np.linalg.norm(n_EMG.flatten())/(10**(snr/10))
            x = y_expanded + l*n_EMG
            if debug:
                snr_check_emg = calc_SNR(np.expand_dims(y_expanded.flatten(), 0), np.expand_dims(x.flatten(), 0))
            X.append(x)
            Y.append(y_expanded)
    else:
        raise Exception("Wrong artifact type.")

    X = np.vstack(X)
    Y = np.vstack(Y)

    return X, Y


def get_rdm_EEG_segment_DenoiseNet (config, artifact_type, snr, debug=False):
    np.random.seed(0)  # for reproducibility

    y = np.load(config['EEG_path'])        # clean segments

    #l_sum = 0

    ############### EOG ##################
    if artifact_type == 'EOG':
        n_EOG = np.load(config['EOG_path'])    #EOG noise segments
        random_idx_eeg = np.random.choice(y.shape[0], 1, replace=False)
        random_idx_noise = np.random.choice(n_EOG.shape[0], 1, replace=False)
        selected_clean_segment = y[random_idx_eeg]
        selected_noise_segment = n_EOG[random_idx_noise]

        l = np.linalg.norm(selected_clean_segment)/np.linalg.norm(selected_noise_segment)/(10**(snr/10))
        x = selected_clean_segment + l*selected_noise_segment
        if debug:
            snr_check_eog = calc_SNR(np.expand_dims(selected_clean_segment, 0), np.expand_dims(x, 0))

    elif artifact_type == 'EMG':
    ############### EMG ##################
        n_EMG = np.load(config['EMG_path'])    #EMG noise segments
        random_idx_eeg = np.random.choice(y.shape[0], 1, replace=False)
        random_idx_noise = np.random.choice(n_EMG.shape[0], 1, replace=False)
        selected_clean_segment = y[random_idx_eeg]
        selected_noise_segment = n_EMG[random_idx_noise]

        l = np.linalg.norm(selected_clean_segment)/np.linalg.norm(selected_noise_segment)/(10**(snr/10))
        x = selected_clean_segment + l*selected_noise_segment
        if debug:
            snr_check_emg = calc_SNR(np.expand_dims(selected_clean_segment, 0), np.expand_dims(x, 0))

    else:
        raise Exception("Wrong artifact type.")

    return x, selected_clean_segment

# Model Tracer

In [30]:
class Model_Tracer():
    def __init__(self, monitor="loss", mode="min", do_save=False, root=None, prefix="checkpoint"):
        if mode not in ["min", "max"]:
            raise ValueError("mode can only be `min` or `max`")
        self.mode = mode
        self.monitor = monitor
        self.do_save = do_save
        self.bound = np.inf if mode == "min" else (-np.inf)
        self.root = os.path.join(expanduser("~"), "Downloads") if root is None else root
        self.prefix = prefix

    def on_epoch_end(self, epoch, logs):
        if ((self.mode == "min" and logs[self.monitor] < self.bound) or
            (self.mode == "max" and logs[self.monitor] > self.bound)
        ):
            print("Epoch {}: {} is improved from {:.6f} to {:.6f}".format(
                epoch, self.monitor, self.bound, logs[self.monitor]
            ))
            self.bound = logs[self.monitor]
            if self.do_save:
                filename = "{}.pth".format(self.prefix)
                torch.save(logs, os.path.join(self.root, filename))

# Model select


In [31]:
def model_select(model_class, model_cfg):
    if model_class == 'CLEEGN':
        model = CLEEGN(n_chan=model_cfg['n_chan'], fs=128, N_F=model_cfg['N_F'])
    elif model_class == 'xLSTM':
        xlstm_cfg = model_cfg['cfg']
        xlstm_cfg = OmegaConf.create(xlstm_cfg)
        xlstm_cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(xlstm_cfg), config=DaciteConfig(strict=True))
        xlstm_stack = xLSTMBlockStack(xlstm_cfg)
        model = xlstm_stack
    elif model_class == 'Seq2Seq':
        model = Seq2Seq(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['hidden_dim'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers'])
    elif model_class == 'Seq2Seq_LSTM':
        model = Seq2SeqLSTM(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['hidden_dim'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers'])
    elif model_class == 'LSTM':
        model = LSTM(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['hidden_dim'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers'])
    elif model_class == 'Seq2Seq_Attention':
        model = Seq2SeqWithAttention(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['hidden_dim'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers'])
    elif model_class == 'LSTM_Autoencoder':
        model = LSTMAutoencoder(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['hidden_dim'], latent_dim=model_cfg['latent_dim'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers'])
    elif model_class == 'Transformer':
        model = TransformerDenoiser(input_dim=model_cfg['n_chan'], embed_dim = model_cfg['embed_dim'], num_heads = model_cfg['num_heads'], num_layers = model_cfg['num_layers'], hidden_dim = model_cfg['hidden_dim'], dropout = model_cfg['dropout'], max_len = model_cfg['max_len'])
    elif model_class == 'Autoencoder_CNN':
        model = ConvAutoencoder(input_channels=model_cfg['n_chan'])
    elif model_class == 'Autoencoder_CNN_Compress':
        model = ConvAutoencoder_Compress(input_channels=model_cfg['n_chan'])
    elif model_class == 'Autoencoder_CNN_LSTM':
        model = LSTMConvAutoencoder(input_channels=model_cfg['n_chan'], sequence_length=model_cfg['sequence_length'], hidden_dim=model_cfg['hidden_dim'], latent_dim=model_cfg['latent_dim'])
    elif model_class == 'Autoencoder_CNN_LSTM2':
        model = LSTMConvAutoencoder2(input_dim=model_cfg['n_chan'], hidden_dim = model_cfg['n_chan'], num_layers = model_cfg['num_layers'])
    elif model_class == 'Autoencoder_CNN_LSTM3':
        model = LSTMConvAutoencoder3(input_dim=model_cfg['n_chan'], num_layers = model_cfg['num_layers'])
    elif model_class == 'Autoencoder_CNN_LSTM4':
        model = LSTMConvAutoencoder4(input_dim=model_cfg['n_chan'], num_layers = model_cfg['num_layers'])
    elif model_class == 'Parallel_CNN_LSTM':
        model = Parallel_CNN_LSTM(lstm_model=LSTM(input_dim=model_cfg['n_chan'], hidden_dim=model_cfg['n_chan'], output_dim=model_cfg['n_chan'], num_layers=model_cfg['num_layers']), cnn_model=ConvAutoencoder(input_channels=model_cfg['n_chan']), n_chan=model_cfg['n_chan'], learn_concat=model_cfg['learn_concat'])
    elif model_class == 'IC_U_Net':
        #model = IC_U_NET(n_channels=model_cfg['n_chan'], bilinear=model_cfg['bilinear'])
        model = IC_U_NET(input_channels=model_cfg['n_chan'])
    elif model_class == 'OneD_Res_CNN':
        model = OneD_ResCNN(seq_length=model_cfg['seq_length'], batch_size=model_cfg['batch_size'], n_chan=model_cfg['n_chan'])
    return model

# TRAINING

In [32]:
def train(tra_loader, model, criterion, optimizer, model_class, normalize = False, ensemble_loss = False, use_wandb = False, verbose=1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()  # switch to train mode

    log = ""
    ep_time0 = time.time()
    epoch_loss = np.zeros((len(tra_loader), ))
    for i, (x_batch, y_batch) in enumerate(tra_loader):
        # print(i, x_batch.shape, y_batch.shape)
        x_batch, y_batch = x_batch.to(device, dtype=torch.float), y_batch.to(device, dtype=torch.float)

        optimizer.zero_grad()

        if normalize:
            #TODO: either restore (output times std) and calc loss on restored, or calc loss on normalized and only restore at inference
            # Calculate the standard deviation for each batch along channel and sequence dimensions
            std_per_batch = x_batch.std(dim=(2, 3))  # Keeps the result in shape [batch_size, 1]
            percentile_95_per_batch = torch.quantile(torch.abs(x_batch), 0.95, dim=3).squeeze(-1)
            for j in range(std_per_batch.shape[0]):
                x_batch[j] = x_batch[j]/percentile_95_per_batch[j]
                y_batch[j] = y_batch[j]/percentile_95_per_batch[j]


        if model_class == "Seq2Seq" or model_class == 'Transformer' or model_class == "Seq2Seq_Attention":
            output = model(x_batch, y_batch)
        else:
            output = model(x_batch)

        #if normalize:
            #for j in range(std_per_batch.shape[0]):
                #output[j] = output[j]*std_per_batch[j]

        if ensemble_loss:
            v_x, v_y, a_x, a_y, fft_x, fft_y = calc_vel_acc_freq(output, y_batch)
            loss_ampl = criterion(output, y_batch)
            loss_vel = criterion(v_x, v_y)
            loss_acc = criterion(a_x, a_y)
            loss_freq = criterion(fft_x, fft_y)
            loss = loss_ampl + loss_vel + loss_acc + loss_freq
        else:
            loss = criterion(output, y_batch)

        loss.backward()
        #if use_wandb:
            #log_gradients_to_wandb(model)
        optimizer.step()

        epoch_loss[i] = loss.item()

        if (np.isnan(loss.item())):
            print(x_batch)
            print(x_batch)
        if verbose:
            print("\r{}".format(" " * len(log)), end="")
            log = "\r{}/{} - {:.4f} s - loss: {:.4f} - acc: nan".format(
                i + 1, len(tra_loader), time.time() - ep_time0, epoch_loss[i]
            )
            print(log, end="")

            #if (i == 0 or i == len(tra_loader)-1 or i%10 == 0):
            if (0==1):
                x_b_nump = x_batch.numpy()
                y_b_nump = y_batch.numpy()
                out = output.detach().numpy()

                plt.plot(x_b_nump[0,:,:][0,:], label = 'x')
                plt.plot(y_b_nump[0,:,:][0,:], label = 'y')
                plt.plot(out[0,:,:][0,:], label = 'out')
                plt.legend()
                plt.savefig("test.pdf", format="pdf", bbox_inches="tight")
                plt.show()

                # CLEEGN:
                #plt.plot(x_b_nump[0,:,1,:][0,:])
                #plt.plot(y_b_nump[0,:,1,:][0,:])
                #plt.plot(out[0,:,1,:][0,:])
                #plt.show()


                # mse = np.zeros(18)
                # for i in range(18):
                #     plt.plot(x_b_nump[0,:,i,:][0,:])
                #     plt.plot(y_b_nump[0,:,i,:][0,:])
                #     plt.plot(out[0,:,i,:][0,:])
                #     plt.show()
                #     mse[i] = np.mean((y_b_nump[0,:,i,:][0,:]-out[0,:,i,:][0,:])**2)

    return epoch_loss.mean(axis=0)

In [33]:
def val(val_loader, model, criterion, model_class, normalize = False, verbose=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval() # switch to evaluation mode

    log = ""
    ep_time0 = time.time()
    epoch_loss = np.zeros((len(val_loader), ))
    for i, (x_batch, y_batch) in enumerate(val_loader):
        x_batch, y_batch = x_batch.to(device, dtype=torch.float), y_batch.to(device, dtype=torch.float)

        if normalize:
            # Calculate the standard deviation for each batch along channel and sequence dimensions
            std_per_batch = x_batch.std(dim=(2, 3))  # Keeps the result in shape [batch_size, 1]
            percentile_95_per_batch = torch.quantile(torch.abs(x_batch), 0.95, dim=3).squeeze(-1)
            for j in range(std_per_batch.shape[0]):
                x_batch[j] = x_batch[j]/percentile_95_per_batch[j]
                y_batch[j] = y_batch[j]/percentile_95_per_batch[j]

        with torch.no_grad():
            if model_class == "Seq2Seq" or model_class == "Seq2Seq_Attention" or model_class == 'Transformer':
                output = model(x_batch, y_batch)
            else:
                output = model(x_batch)

        #if normalize:
            #for j in range(std_per_batch.shape[0]):
                #output[j] = output[j]*std_per_batch[j]

        # output = x_batch # DEBUG
        loss = criterion(output, y_batch)

        epoch_loss[i] = loss.item()
        if verbose:
            print("\r{}".format(" " * len(log)), end="")
            log = "\r{}/{} - {:.4f} s - loss: {:.4f} - acc: nan".format(
                i + 1, len(val_loader), time.time() - ep_time0, epoch_loss[i]
            )
            print(log, end="")
    return epoch_loss.mean(axis=0)

In [34]:
def log_gradients_to_wandb(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            wandb.log({f"grad_norms/{name}": param.grad.norm().item()})

def calc_vel_acc_freq(x, y, freq = 128):
    dt = 1.0 / freq  # Time step
    dt = 1.0
    # Velocity
    v_x = torch.diff(x, n=1, dim=-1) / dt
    v_y = torch.diff(y, n=1, dim=-1) / dt

    # Acceleration
    a_x = torch.diff(v_x, n=1, dim=-1) / dt
    a_y = torch.diff(v_y, n=1, dim=-1) / dt

    # Pad the reconstructed velocity and acceleration for consistent size
    v_x = F.pad(v_x, (0, 1), mode='constant')
    v_y = F.pad(v_y, (0, 1), mode='constant')
    a_x = F.pad(a_x, (0, 2), mode='constant')
    a_y = F.pad(a_y, (0, 2), mode='constant')

    # Frequency estimate using Fourier Transform and calculate MSE for frequency spectrum
    fft_x = torch.abs(torch.fft.fft(x, dim=-1))
    fft_y = torch.abs(torch.fft.fft(y, dim=-1))
    return v_x, v_y, a_x, a_y, fft_x, fft_y

In [35]:
USE_WANDB = True

In [36]:
def main_fct(config = None):
    import argparse

    if USE_WANDB:
        wandb.init(config=config)
        config = wandb.config
        MODEL_CLASS = config.model_class
    else:
        MODEL_CLASS = config['model_class']

    DATASET = 'TUH'    # either 'TUH' or 'BCI' or 'DenoiseNet'
    artifact_type = 'EOG'
    ensemble_loss = False


    if DATASET == 'BCI':
        config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/bci-challenge/config.yml'
        model_config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/bci-challenge/model_config.yml'

    elif DATASET == 'TUH':
        config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/tusz/config.yml'
        model_config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/tusz/model_config.yml'

    elif DATASET == 'DenoiseNet':
        config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/EEG_DenoiseNet/config.yml'
        model_config_path = '/content/drive/My Drive/A_EEG/CLEEGN/configs/EEG_DenoiseNet/model_config.yml'


    model_name = yaml.safe_load(Path(config_path).read_text())['model_name']
    cfg_dataset = yaml.safe_load(Path(config_path).read_text())['Dataset']
    cfg_general = yaml.safe_load(Path(config_path).read_text())
    cfg_model = yaml.safe_load(Path(model_config_path).read_text())[MODEL_CLASS]

    SFREQ      = cfg_dataset["sfreq"]
    normalize  = cfg_dataset["normalize"]
    NUM_EPOCHS = cfg_general['epochs']
    BATCH_SIZE = cfg_model['batch_size']
    LR         = cfg_model["learning_rate"]
    scheduler_type = cfg_model["scheduler"]['type']

    if scheduler_type == 'MultiStepLR':
        scheduler_milestones = cfg_model["scheduler"]['milestones']
    if scheduler_type == 'MultiStepLR' or scheduler_type == 'ExponentialLR':
        scheduler_gamma = cfg_model["scheduler"]['gamma']


    # Save path
    if cfg_general["save_path"] is None:
        if DATASET == 'DenoiseNet':
            SAVE_PATH = '/content/drive/My Drive/A_EEG/CLEEGN/logs/' + DATASET + '_' + artifact_type + '/' + MODEL_CLASS
        else:
            SAVE_PATH = '/content/drive/My Drive/A_EEG/CLEEGN/logs/' + DATASET + '/' + MODEL_CLASS

        if not os.path.exists(SAVE_PATH):
            try:
                os.makedirs(SAVE_PATH)
            except Exception as e:
                print(f"Failed to create directory '{SAVE_PATH}': {e}")
    else:
        SAVE_PATH = cfg_general["save_path"]

    timestamp = datetime.now().strftime("%b%d_%H-%M-%S")

    if DATASET == 'TUH' or DATASET == 'BCI':
        x_train, y_train = create_dataset(
            os.path.join(cfg_dataset["x_basepath"], cfg_dataset["x_fpath"]),
            os.path.join(cfg_dataset["y_basepath"], cfg_dataset["y_fpath"]),
            cfg_dataset["subjects_train"], tmin=cfg_dataset["tmin"], tmax=cfg_dataset["tmax"],
            ch_names=cfg_dataset["ch_names"], win_size=cfg_dataset["window_size"], stride=cfg_dataset["stride"]
        )
    elif DATASET == 'DenoiseNet':
        x, y = create_EEG_DenoiseNet_dataset(cfg_dataset, artifact_type, debug = True)
        x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=0.2, random_state=0, shuffle=True)
        x_train = np.expand_dims(x_train,1)
        y_train = np.expand_dims(y_train,1)
        x_valid = np.expand_dims(x_valid,1)
        y_valid = np.expand_dims(y_valid,1)


    x_train = np2TT(np.expand_dims(x_train, axis=1))
    y_train = np2TT(np.expand_dims(y_train, axis=1))

    if MODEL_CLASS == 'xLSTM':
        x_train, y_train = x_train.permute(0,1,3,2).squeeze(), y_train.permute(0,1,3,2).squeeze()

    if DATASET == 'TUH' or DATASET == 'BCI':
        x_valid, y_valid = create_dataset(
            os.path.join(cfg_dataset["x_basepath"], cfg_dataset["x_fpath"]),
            os.path.join(cfg_dataset["y_basepath"], cfg_dataset["y_fpath"]),
            cfg_dataset["subjects_val"], tmin=cfg_dataset["tmin"], tmax=cfg_dataset["tmax"],
            ch_names=cfg_dataset["ch_names"], win_size=cfg_dataset["window_size"], stride=cfg_dataset["stride"]
        )

    x_valid = np2TT(np.expand_dims(x_valid, axis=1))
    y_valid = np2TT(np.expand_dims(y_valid, axis=1))

    if MODEL_CLASS == 'xLSTM':
        x_valid, y_valid = x_valid.permute(0,1,3,2).squeeze(), y_valid.permute(0,1,3,2).squeeze()

    print(x_train.size(), y_train.size(), x_valid.size(), y_valid.size())

    trainset = torch.utils.data.TensorDataset(x_train, y_train)
    tra_loader = torch.utils.data.DataLoader(
        trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True
    )
    validset = torch.utils.data.TensorDataset(x_valid, y_valid)
    val_loader = torch.utils.data.DataLoader(
        validset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True
    )



    model = model_select(MODEL_CLASS, cfg_model).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print ('Total # of modelparameters: ', str(total_params))

    #summary(model, input_size=(BATCH_SIZE, 1, x_train.size()[2], x_train.size()[3]))

    ckpts = [
        Model_Tracer(monitor="loss", mode="min"),
        Model_Tracer(monitor="val_loss", mode="min", do_save=True, root=SAVE_PATH, prefix= MODEL_CLASS + '_' + timestamp),
    ]
    criteria = nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    #optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    if scheduler_type == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer = optimizer, milestones = scheduler_milestones, gamma = scheduler_gamma)
    elif scheduler_type == 'ExponentialLR':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer = optimizer, gamma=scheduler_gamma)
    else:
        scheduler = None

    tra_time0 = time.time()
    loss_curve = {"epoch": [], "loss": [], "val_loss": []}


    for epoch in range(NUM_EPOCHS):
        loss = train(tra_loader, model, criteria, optimizer, MODEL_CLASS, normalize, ensemble_loss, USE_WANDB)

        """ validation """
        val_loss = val(val_loader, model, criteria, MODEL_CLASS, normalize)
        lr = optimizer.param_groups[-1]['lr']
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        print("\rEpoch {}/{} - {:.2f} s - loss: {:.4f} - val_loss: {:.4f} - lr: {:e}".format(
            epoch + 1, NUM_EPOCHS, time.time() - tra_time0, loss, val_loss, lr
        ))
        state = dict(
            epoch=epoch + 1, min_loss=ckpts[0].bound, min_vloss=ckpts[1].bound,
            state_dict=model.state_dict(), loss=loss, val_loss=val_loss, learning_rate=lr
        )
        for ckpt in ckpts:
            ckpt.on_epoch_end(epoch + 1, state)

        if USE_WANDB:
            wandb.log({"epoch": epoch, "loss": loss, "val_loss": val_loss})

        loss_curve["epoch"].append(epoch + 1)
        loss_curve["loss"].append(loss)
        loss_curve["val_loss"].append(val_loss)
    ### End_Of_Train
    savemat(os.path.join(SAVE_PATH, "loss_curve.mat"), loss_curve)


In [None]:
if __name__ == "__main__":
    if USE_WANDB:
        wandb.login()

        sweep_config = {
                'method': 'grid',
            }

        parameters_dict = {
        'model_class': {
            'values': ['Autoencoder_CNN_LSTM3', 'CLEEGN', 'IC_U_Net', 'OneD_Res_CNN'] #['LSTM', 'Autoencoder_CNN', 'xLSTM', 'Autoencoder_CNN_LSTM2', 'Autoencoder_CNN_LSTM3', 'Autoencoder_CNN_LSTM4', 'Parallel_CNN_LSTM', 'CLEEGN', 'IC_U_Net', 'OneD_Res_CNN']
            },
        }

        sweep_config['parameters'] = parameters_dict

        sweep_id = wandb.sweep(sweep_config, project="EEG_Denoising")

        wandb.agent(sweep_id, main_fct)
    else:
        config = {
        'model_class': 'Autoencoder_CNN_LSTM3' #['LSTM', 'Autoencoder_CNN', 'xLSTM', 'Autoencoder_CNN_LSTM2', 'Autoencoder_CNN_LSTM3', 'Autoencoder_CNN_LSTM4', 'Parallel_CNN_LSTM', 'CLEEGN', 'IC_U_Net', 'OneD_Res_CNN']
            }
        main_fct(config)

Create sweep with ID: fj2kcvbp
Sweep URL: https://wandb.ai/team_gtb/EEG_Denoising/sweeps/fj2kcvbp


[34m[1mwandb[0m: Agent Starting Run: imxcicdz with config:
[34m[1mwandb[0m: 	model_class: Autoencoder_CNN_LSTM3


  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0

torch.Size([3237, 1, 18, 512]) torch.Size([3237, 1, 18, 512]) torch.Size([1494, 1, 18, 512]) torch.Size([1494, 1, 18, 512])
Total # of modelparameters:  1474450




Epoch 1/60 - 3.84 s - loss: 0.0546 - val_loss: 0.0276 - lr: 1.000000e-03
Epoch 1: loss is improved from inf to 0.054576
Epoch 1: val_loss is improved from inf to 0.027574
Epoch 2/60 - 7.80 s - loss: 0.0218 - val_loss: 0.0169 - lr: 1.000000e-03
Epoch 2: loss is improved from 0.054576 to 0.021752
Epoch 2: val_loss is improved from 0.027574 to 0.016929
Epoch 3/60 - 10.28 s - loss: 0.0156 - val_loss: 0.0131 - lr: 1.000000e-03
Epoch 3: loss is improved from 0.021752 to 0.015581
Epoch 3: val_loss is improved from 0.016929 to 0.013124
Epoch 4/60 - 12.81 s - loss: 0.0124 - val_loss: 0.0103 - lr: 1.000000e-03
Epoch 4: loss is improved from 0.015581 to 0.012438
Epoch 4: val_loss is improved from 0.013124 to 0.010288
Epoch 5/60 - 15.31 s - loss: 0.0105 - val_loss: 0.0086 - lr: 1.000000e-03
Epoch 5: loss is improved from 0.012438 to 0.010476
Epoch 5: val_loss is improved from 0.010288 to 0.008592
Epoch 6/60 - 18.14 s - loss: 0.0087 - val_loss: 0.0077 - lr: 1.000000e-03
Epoch 6: loss is improved fr

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇███
loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,59.0
loss,0.00059
val_loss,0.00069


[34m[1mwandb[0m: Agent Starting Run: 0by75vlj with config:
[34m[1mwandb[0m: 	model_class: CLEEGN


  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0

torch.Size([3237, 1, 18, 512]) torch.Size([3237, 1, 18, 512]) torch.Size([1494, 1, 18, 512]) torch.Size([1494, 1, 18, 512])
Total # of modelparameters:  10767
Epoch 1/60 - 4.80 s - loss: 0.0598 - val_loss: 0.0305 - lr: 1.000000e-03
Epoch 1: loss is improved from inf to 0.059767
Epoch 1: val_loss is improved from inf to 0.030513
Epoch 2/60 - 9.52 s - loss: 0.0214 - val_loss: 0.0192 - lr: 8.000000e-04
Epoch 2: loss is improved from 0.059767 to 0.021446
Epoch 2: val_loss is improved from 0.030513 to 0.019178
Epoch 3/60 - 15.12 s - loss: 0.0159 - val_loss: 0.0147 - lr: 6.400000e-04
Epoch 3: loss is improved from 0.021446 to 0.015946
Epoch 3: val_loss is improved from 0.019178 to 0.014655
Epoch 4/60 - 19.81 s - loss: 0.0131 - val_loss: 0.0123 - lr: 5.120000e-04
Epoch 4: loss is improved from 0.015946 to 0.013105
Epoch 4: val_loss is improved from 0.014655 to 0.012325
Epoch 5/60 - 24.58 s - loss: 0.0113 - val_loss: 0.0098 - lr: 4.096000e-04
Epoch 5: loss is improved from 0.013105 to 0.011328

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,59.0
loss,0.00719
val_loss,0.00573


[34m[1mwandb[0m: Agent Starting Run: vpsrjea0 with config:
[34m[1mwandb[0m: 	model_class: IC_U_Net


  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0

torch.Size([3237, 1, 18, 512]) torch.Size([3237, 1, 18, 512]) torch.Size([1494, 1, 18, 512]) torch.Size([1494, 1, 18, 512])
Total # of modelparameters:  4214006
Epoch 1/60 - 5.98 s - loss: 0.1287 - val_loss: 0.1210 - lr: 1.000000e-03
Epoch 1: loss is improved from inf to 0.128716
Epoch 1: val_loss is improved from inf to 0.121036
Epoch 2/60 - 10.99 s - loss: 0.0953 - val_loss: 0.1159 - lr: 1.000000e-03
Epoch 2: loss is improved from 0.128716 to 0.095282
Epoch 2: val_loss is improved from 0.121036 to 0.115944
Epoch 3/60 - 16.54 s - loss: 0.0930 - val_loss: 0.1142 - lr: 1.000000e-03
Epoch 3: loss is improved from 0.095282 to 0.093022
Epoch 3: val_loss is improved from 0.115944 to 0.114249
Epoch 4/60 - 22.35 s - loss: 0.0920 - val_loss: 0.1132 - lr: 1.000000e-03
Epoch 4: loss is improved from 0.093022 to 0.091980
Epoch 4: val_loss is improved from 0.114249 to 0.113200
Epoch 5/60 - 27.47 s - loss: 0.0947 - val_loss: 0.1169 - lr: 1.000000e-03
Epoch 6/60 - 33.17 s - loss: 0.0922 - val_loss: 

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
loss,█▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▅▄▅▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,59.0
loss,0.08807
val_loss,0.11004


[34m[1mwandb[0m: Agent Starting Run: sxm4xrxv with config:
[34m[1mwandb[0m: 	model_class: OneD_Res_CNN


  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_terms[0]), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  y_raw = mne.io.read_raw_eeglab(y_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0)
  x_raw = mne.io.read_raw_eeglab(x_fpath.format(*fmt_term), verbose=0

torch.Size([3237, 1, 18, 512]) torch.Size([3237, 1, 18, 512]) torch.Size([1494, 1, 18, 512]) torch.Size([1494, 1, 18, 512])
Total # of modelparameters:  325891
Epoch 1/60 - 13.66 s - loss: 0.0900 - val_loss: 0.0599 - lr: 1.000000e-03
Epoch 1: loss is improved from inf to 0.089981
Epoch 1: val_loss is improved from inf to 0.059879
Epoch 2/60 - 27.36 s - loss: 0.0501 - val_loss: 0.0479 - lr: 1.000000e-03
Epoch 2: loss is improved from 0.089981 to 0.050105
Epoch 2: val_loss is improved from 0.059879 to 0.047946
Epoch 3/60 - 41.14 s - loss: 0.0385 - val_loss: 0.0394 - lr: 1.000000e-03
Epoch 3: loss is improved from 0.050105 to 0.038485
Epoch 3: val_loss is improved from 0.047946 to 0.039352
Epoch 4/60 - 55.02 s - loss: 0.0330 - val_loss: 0.0372 - lr: 1.000000e-03
Epoch 4: loss is improved from 0.038485 to 0.032957
Epoch 4: val_loss is improved from 0.039352 to 0.037196
Epoch 5/60 - 68.90 s - loss: 0.0317 - val_loss: 0.0377 - lr: 1.000000e-03
Epoch 5: loss is improved from 0.032957 to 0.031

# TEST/INFERENCE

In [None]:
electrode = ['FP1', 'FP2', 'F3', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ']
""" pyplot waveform visualization """
def viewARA(tstmps, data_colle, ref_i, electrode, titles=None, colors=None, alphas=None, ax=None):
    n_data = len(data_colle)
    titles = ["" for di in range(n_data)] if titles is None else titles
    alphas = [0.5 for di in range(n_data)] if alphas is None else alphas
    if colors is None:
        cmap_ = plt.cm.get_cmap("tab20", n_data)
        colors = [rgb2hex(cmap_(di)) for di in range(n_data)]

    picks_chs = ["FP1", "FP2", "F7", "T4", "O2"]
    picks = [electrode.index(c) for c in picks_chs]
    for di in range(n_data):
        data_colle[di] = data_colle[di][picks, :]
    if ax is None:
        ax = plt.subplot()
    for ii, ch_name in enumerate(picks_chs):
        offset = len(picks) - ii - 1
        norm_coef = 0.25 / np.abs(data_colle[ref_i][ii]).max()
        for di in range(n_data):
            eeg_dt = data_colle[di]
            ax.plot(tstmps, eeg_dt[ii] * norm_coef + offset,
                label=None if ii else titles[di], color=colors[di], alpha=alphas[di],
                linewidth=3 if alphas[di] > 0.6 else 1.5, # default=1.5
            )
    ax.set_xlim(tstmps[0], tstmps[-1])
    ax.set_ylim(-0.5, len(picks) - 0.5)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_yticks(np.arange(len(picks)))
    ax.set_yticklabels(picks_chs[::-1], fontsize=20)
    ax.legend(
        bbox_to_anchor=(0, 1.02, 1, 0.2),
        loc="lower right", borderaxespad=0, ncol=3, fontsize=20
    )

def ar_through_model(eeg_data, model, window_size, stride):
    model.eval()

    noiseless_eeg = np.zeros(eeg_data.shape, dtype=np.float32)
    hcoef = np.zeros(eeg_data.shape[1], dtype=np.float32)

    hwin = signal.windows.hann(window_size) + 1e-9
    for i in range(0, noiseless_eeg.shape[1], stride):
        tstap, LAST_FRAME = i, False
        segment = eeg_data[:, tstap: tstap + window_size]
        if segment.shape[1] != window_size:
            tstap = noiseless_eeg.shape[1] - window_size
            segment = eeg_data[:, tstap:]
            LAST_FRAME = True
        with torch.no_grad():
            segment = np.expand_dims(segment, axis=0)
            data = np2TT(np.expand_dims(segment, axis=0))
            if MODEL_CLASS == 'xLSTM':
                data = data.permute(0,1,3,2).squeeze(0)  #ADDED
            data = data.to(device, dtype=torch.float)
            pred_segment = model(data)
            if MODEL_CLASS == 'xLSTM':
                pred_segment = pred_segment.permute(0,2,1)
            pred_segment = np.array(pred_segment.cpu()).astype(np.float32)   #pred_segment [1, n_chan, seq_length]
        noiseless_eeg[:, tstap: tstap + window_size] += pred_segment.squeeze() * hwin
        hcoef[tstap: tstap + window_size] += hwin

        if LAST_FRAME:
            break
    noiseless_eeg /= hcoef

    return noiseless_eeg



if __name__ == "__main__":
    MODEL_FILE_NAME = 'OneD_Res_CNN_Nov08_10-16-12.pth'
    TEST_DATA_PATH = '/content/drive/MyDrive/A_EEG/CLEEGN/sampleData/Data_S016_norm.mat'

    model_path = os.path.join('/content/drive/MyDrive/A_EEG/CLEEGN', 'logs', DATASET, MODEL_CLASS, MODEL_FILE_NAME)
    test_data = loadmat(TEST_DATA_PATH)
    dt_polluted, dt_ref = test_data["x_test"], test_data["y_test"]

    ### temporary fixed mode
    state_path = os.path.join(model_path)
    state = torch.load(state_path, map_location="cpu")

    if MODEL_CLASS == 'xLSTM':
        xlstm_stack = xLSTMBlockStack(xlstm_cfg)
    #model = CLEEGN(n_chan=x_train.size()[2], fs=SFREQ, N_F=x_train.size()[2]).to(device)
    #model = xlstm_stack.to(device)
    model = model_select(MODEL_CLASS, cfg_model).to(device)
    model.load_state_dict(state["state_dict"])
    #model.load_state_dict(torch.load(model_path))
    dt_cleegn = ar_through_model(
        dt_polluted, model, math.ceil(4.0 * 128.0), math.ceil(1 * 128.0)
    )

    start = 6000
    x_min, x_max = start, start + 500
    x_data = dt_polluted[:, x_min: x_max]
    y_data = dt_ref[:, x_min: x_max]
    p_data = dt_cleegn[:, x_min: x_max]
    fig, ax = plt.subplots(1, 1, figsize=(16, 9))
    viewARA(
        np.linspace(0, math.ceil(x_data.shape[-1] / 128.0), x_data.shape[-1]),
        [x_data, y_data, y_data, p_data], 1, electrode,
        titles=["Original", "", "Reference", MODEL_CLASS], colors=["gray", "gray", "red", "blue"], alphas=[0.5, 0, 0.8, 0.8], ax=ax
    )
    plt.savefig("test.pdf", format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
MODEL_FILE_NAME = 'OneD_Res_CNN_Nov13_13-35-41.pth'
artifact_type = 'EOG'
snr_synthetic_testData = 4 #in dezibel
plt_interval = [0, 512]


if DATASET == 'BCI':
    #config_path = 'configs/BCI_KAGGLE/config.yml'
    #model_config_path = 'configs/BCI_KAGGLE/model_config.yml'
    electrode = ['Fp1', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'POz', 'PO8', 'O1', 'O2']
    picks_chs = ["Fp1", "Fp2", "F7", "T7", "O2"]
    TEST_DATA_PATH = 'sampleData\Data_S14.mat'  #TBD

elif DATASET == 'TUH':
    #config_path = 'configs/TUH/config.yml'
    #model_config_path = 'configs/TUH/model_config.yml'
    electrode = ['FP1', 'FP2', 'F3', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ']
    picks_chs = ["FP1", "FP2", "F7", "T4", "O2"]
    TEST_DATA_PATH = '/content/drive/MyDrive/A_EEG/CLEEGN/sampleData/Data_S016_norm.mat'

elif DATASET == 'DenoiseNet':
    #config_path = 'configs/EEG_DenoiseNet/config.yml'
    #model_config_path = 'configs/EEG_DenoiseNet/model_config.yml'
    electrode = ['ch1']
    picks_chs = ['ch1']


cfg_model = yaml.safe_load(Path(model_config_path).read_text())[MODEL_CLASS]
cfg_dataset = yaml.safe_load(Path(config_path).read_text())['Dataset']
SFREQ      = cfg_dataset["sfreq"]


os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)




torch_dtype_map: dict[str, torch.dtype] = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}
xlstm_cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(
            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=3
        )
    ),
    slstm_block=sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(
            backend="vanilla",
            num_heads=1,
            conv1d_kernel_size=4,
            bias_init="powerlaw_blockdependent",
        ),
        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
    ),
    context_length=512,
    num_blocks=1,
    embedding_dim=18,
    slstm_at=[0],
)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


""" pyplot waveform visualization """
def viewARA(tstmps, data_colle, ref_i, electrode, titles=None, colors=None, alphas=None, ax=None, picks_channel=None):
    n_data = len(data_colle)
    titles = ["" for di in range(n_data)] if titles is None else titles
    alphas = [0.5 for di in range(n_data)] if alphas is None else alphas
    if colors is None:
        cmap_ = plt.cm.get_cmap("tab20", n_data)
        colors = [rgb2hex(cmap_(di)) for di in range(n_data)]


    picks = [electrode.index(c) for c in picks_chs]
    for di in range(n_data):
        data_colle[di] = data_colle[di][picks, :]
    if ax is None:
        ax = plt.subplot()
    for ii, ch_name in enumerate(picks_chs):
        offset = len(picks) - ii - 1
        norm_coef = 0.25 / np.abs(data_colle[ref_i][ii]).max()
        for di in range(n_data):
            eeg_dt = data_colle[di]
            ax.plot(tstmps, eeg_dt[ii] * norm_coef + offset,
                label=None if ii else titles[di], color=colors[di], alpha=alphas[di],
                linewidth=3 if alphas[di] > 0.6 else 1.5, # default=1.5
            )
    ax.set_xlim(tstmps[0], tstmps[-1])
    ax.set_ylim(-0.5, len(picks) - 0.5)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_yticks(np.arange(len(picks)))
    ax.set_yticklabels(picks_chs[::-1], fontsize=20)
    ax.legend(
        bbox_to_anchor=(0, 1.02, 1, 0.2),
        loc="lower right", borderaxespad=0, ncol=3, fontsize=20
    )

def ar_through_model(eeg_data, model, window_size, stride):
    model.eval()

    noiseless_eeg = np.zeros(eeg_data.shape, dtype=np.float32)
    hcoef = np.zeros(eeg_data.shape[1], dtype=np.float32)

    hwin = signal.windows.hann(window_size) + 1e-9
    for i in range(0, noiseless_eeg.shape[1], stride):
        tstap, LAST_FRAME = i, False
        segment = eeg_data[:, tstap: tstap + window_size]
        if segment.shape[1] != window_size:
            tstap = noiseless_eeg.shape[1] - window_size
            segment = eeg_data[:, tstap:]
            LAST_FRAME = True
        with torch.no_grad():
            segment = np.expand_dims(segment, axis=0)
            data = np2TT(np.expand_dims(segment, axis=0))
            if MODEL_CLASS == 'xLSTM':
                data = data.permute(0,1,3,2).squeeze(0)  #ADDED
            data = data.to(device, dtype=torch.float)
            pred_segment = model(data)
            if MODEL_CLASS == 'xLSTM':
                pred_segment = pred_segment.permute(0,2,1)
            pred_segment = np.array(pred_segment.cpu()).astype(np.float32)   #pred_segment [1, n_chan, seq_length]
        noiseless_eeg[:, tstap: tstap + window_size] += pred_segment.squeeze() * hwin
        hcoef[tstap: tstap + window_size] += hwin

        if LAST_FRAME:
            break
    noiseless_eeg /= hcoef

    return noiseless_eeg

def calc_SNR(clean_data, noisy_data, inDezibel = True):
    # clean data: reference data
    # noisy data: data to measure SNR on, e.g. output of the model
    n_chan = clean_data.shape[0]

    if inDezibel:
        return 1/n_chan * np.sum(10 * np.log10(np.linalg.norm(clean_data, axis = 1)/np.linalg.norm(clean_data-noisy_data, axis = 1)))
    else:
        return 1/n_chan * np.sum(np.linalg.norm(clean_data, axis = 1)**2/np.linalg.norm(clean_data-noisy_data, axis = 1)**2)

def calc_MSE(x, y):
    return 1/x.shape[0] * np.sum(1/x.shape[1] * np.linalg.norm(x - y, axis = 1)**2)


if __name__ == "__main__":
    import argparse

    model_path = os.path.join(os.path.abspath(os.getcwd()), 'logs', DATASET, MODEL_CLASS, MODEL_FILE_NAME)

    if DATASET == 'TUH' or DATASET == 'BCI':
        test_data = loadmat(TEST_DATA_PATH)
        noisy_data, reference_data = test_data["x_test"], test_data["y_test"]
    elif DATASET == 'DenoiseNet':
        noisy_data, reference_data = get_rdm_EEG_segment_DenoiseNet(cfg_dataset, artifact_type, snr_synthetic_testData)
        percentile_95 = np.quantile(np.abs(noisy_data.squeeze()), 0.95)
        noisy_data = noisy_data/percentile_95
        reference_data = reference_data/percentile_95


    state_path = os.path.join(model_path)
    state = torch.load(state_path, map_location="cpu")

    #xlstm_stack = xLSTMBlockStack(xlstm_cfg)

    model = model_select(MODEL_CLASS, cfg_model)
    model.load_state_dict(state["state_dict"])

    reconstructed_data = ar_through_model(
        noisy_data, model, math.ceil(4.0 * 128.0), math.ceil(1 * 128.0)
    )



    start = plt_interval[0]
    x_min, x_max = start, start + plt_interval[1]
    x_data = noisy_data[:, x_min: x_max]
    y_data = reference_data[:, x_min: x_max]
    p_data = reconstructed_data[:, x_min: x_max]

    #TODO SNR, MSE berechnen und printen
    snr = calc_SNR(y_data, p_data, inDezibel=False)
    snr_dB = calc_SNR(y_data, p_data, inDezibel=True)
    mse = calc_MSE(y_data, p_data)

    print('Data points of segment: ' + str(p_data.shape[1]))
    print(f'MSE: {mse:.5f}')
    print(f'SNR: {snr_dB:.2f}dB (or {snr:.2f})')

    fig, ax = plt.subplots(1, 1, figsize=(16, 9))
    viewARA(
        np.linspace(0, math.ceil(x_data.shape[-1] / 128.0), x_data.shape[-1]),
        [x_data, y_data, y_data, p_data], 1, electrode,
        titles=["Original", "", "Reference", MODEL_CLASS], colors=["gray", "gray", "red", "blue"], alphas=[0.5, 0, 0.8, 0.8], ax=ax,
        picks_channel = picks_chs
    )
    plt.savefig("inference.pdf", format="pdf", bbox_inches="tight")
    plt.show()