# Multimodal approach using both RAW EEG Features & Spectrograms

#### This is the inference notebook for my multimodal solution to Kaggle's HMS Harmful Brain Activity Classification competition. It scores 0.607 on a 5-fold GroupKFold split.

#### Training Notebook link: https://www.kaggle.com/nischaydnk/training-multimodal-1d-2d-approach-eegs


#### PS: This approach continues the work I shared in my earlier notebooks (listed below); refer to them for more detail on the 1D approach used here. I have also used some code from [@moth](https://www.kaggle.com/alejopaullier)'s 2D baseline [notebook](https://www.kaggle.com/code/alejopaullier/hms-efficientnetb0-pytorch-inference) and ideas shared by [@Chris](https://www.kaggle.com/cdeotte).



#### My previous notebooks:

##### 1D EEGNet training: https://www.kaggle.com/code/nischaydnk/lightning-1d-eegnet-training-pipeline-hbs
##### 1D EEGNet Submission: https://www.kaggle.com/code/nischaydnk/hms-submission-1d-eegnet-pipeline-lightning


#### Overall pipeline in this notebook is somewhat like this.
![img2](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4712534%2Fc4e39641fa8ec13588b62065ae1e9b58%2FScreenshot%202024-02-10%20at%2010.40.36%20PM.png?generation=1707585942607444&alt=media)

## Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
from torch.utils.data import Dataset, DataLoader
from sklearn import model_selection
import torchvision.transforms as transforms
import torchvision.io 
import librosa
from PIL import Image
import albumentations as alb
import torch.multiprocessing as mp
import warnings
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
warnings.filterwarnings('ignore')


## Config Class

In [None]:

class Config:
    """Static settings read by the cells of this notebook.

    The class is used as a plain namespace (attributes are read as
    ``Config.<name>``); it is never instantiated in this file.
    """

    # --- run control ---
    seed = 2024
    epochs = 18
    batch_size = 32
    PRECISION = 16
    PATIENCE = 10
    trn_folds = [0, 1, 2, 3, 4]   # fold ids whose checkpoints are averaged

    # --- optimiser ---
    LR = 7e-3
    weight_decay = 1e-3

    # --- augmentation switches ---
    use_aug = False
    use_mixup = False
    mixup_alpha = 0.02

    # --- model ---
    num_classes = 6
    num_channels = 8              # passed as `in_channels` of the 1D branch
    backbone_2d = 'tf_efficientnet_b0'
    pretrained = False

    # --- data / checkpoint locations ---
    data_root = "/kaggle/input/hms-harmful-brain-activity-classification/"
    kaggle_spec_path = "/kaggle/input/hms-harmful-brain-activity-classification/test_spectrograms/"
    TRAIN_EEGS = "/home/nischay//brain/Data/spec/EEG_Spectrograms/"
    processed_train = None
    model_dir = '/kaggle/input/mulitimod-exp1-5gkf-b08ch-cv606'
    

## Parquet to EEG Signals Numpy Processing

In [None]:
# Wavelet name handed to `denoise` inside `spectrogram_from_eeg`;
# a falsy value (the default here) skips the denoising step.
USE_WAVELET = None 

# Labels used in plot titles/legends for the four spectrogram channels.
NAMES = ['LL','LP','RP','RR']

# Four electrode chains (one per entry of NAMES). `spectrogram_from_eeg`
# takes the difference of each pair of consecutive electrodes in a chain.
FEATS2 = [['Fp1','F7','T3','T5','O1'],
         ['Fp1','F3','C3','P3','O1'],
         ['Fp2','F8','T4','T6','O2'],
         ['Fp2','F4','C4','P4','O2']]



def maddest(d, axis: int = None):
    """
    Return the mean absolute deviation of ``d`` about its mean.

    :param d: array-like of numbers.
    :param axis: axis along which to reduce; ``None`` reduces over all elements.
    :return: a scalar when ``axis`` is ``None``, otherwise an array with
        ``axis`` removed.
    """
    d = np.asarray(d)
    # keepdims=True keeps the mean broadcastable against `d` for every axis;
    # without it the subtraction is only correct for axis=None / axis=0.
    center = np.mean(d, axis, keepdims=True)
    return np.mean(np.absolute(d - center), axis)

def denoise(x: np.ndarray, wavelet: str = 'haar', level: int = 1):
    """
    Wavelet-threshold denoising of a 1D signal.

    The signal is decomposed with ``pywt.wavedec``, every detail band is
    hard-thresholded, and the signal is rebuilt with ``pywt.waverec``.
    ``level`` only selects which band (``coeff[-level]``) the noise estimate
    is taken from; it is not passed to the decomposition.

    NOTE(review): ``pywt`` is not imported in the visible part of this file,
    so calling this function requires that import to exist elsewhere.
    """
    bands = pywt.wavedec(x, wavelet, mode="per")  # multilevel 1D Discrete Wavelet Transform of data.
    noise_sigma = (1/0.6745) * maddest(bands[-level])
    cutoff = noise_sigma * np.sqrt(2*np.log(len(x)))
    bands[1:] = [pywt.threshold(band, value=cutoff, mode='hard') for band in bands[1:]]
    return pywt.waverec(bands, wavelet, mode='per')

def spectrogram_from_eeg(parquet_path, display=False):
    """
    Build a (128, 256, 4) mel-spectrogram image from one EEG parquet file.

    For each of the four electrode chains in ``FEATS2`` the four consecutive
    electrode differences are turned into log-mel spectrograms and averaged
    into one channel of the output.

    :param parquet_path: path of the EEG parquet; its file-name stem is used
        as the recording id in plot titles.
    :param display: when truthy, plot the four spectrogram channels and the
        first differenced signals.
    :return: float32 array of shape (128, 256, 4).
    """
    # Recording id for the plot titles, derived from the file name (same
    # convention as `eeg_from_parquet`) instead of relying on a global `eeg_id`.
    eeg_name = str(parquet_path).split('/')[-1].split('.')[0]

    # LOAD MIDDLE 10,000 ROWS OF THE EEG SERIES (50 SECONDS AT sr=200)
    eeg = pd.read_parquet(parquet_path)
    middle = (len(eeg)-10_000)//2
    eeg = eeg.iloc[middle:middle+10_000]

    # VARIABLE TO HOLD SPECTROGRAM
    img = np.zeros((128,256,4),dtype='float32')

    if display:
        plt.figure(figsize=(10,7))

    signals = []

    for k in range(4):
        COLS = FEATS2[k]
        for kk in range(4):

            # COMPUTE PAIR DIFFERENCES
            x = eeg[COLS[kk]].values - eeg[COLS[kk+1]].values

            # FILL NANS (with the signal mean, or zeros when everything is NaN)
            m = np.nanmean(x)
            if np.isnan(x).mean()<1: x = np.nan_to_num(x,nan=m)
            else: x[:] = 0

            # DENOISE
            if USE_WAVELET:
                x = denoise(x, wavelet=USE_WAVELET)
            signals.append(x)

            # RAW SPECTROGRAM
            mel_spec = librosa.feature.melspectrogram(y=x, sr=200, hop_length=len(x)//256,
                  n_fft=1024, n_mels=128, fmin=0, fmax=20, win_length=128)

            # LOG TRANSFORM (width is trimmed to a multiple of 32 frames)
            width = (mel_spec.shape[1]//32)*32
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max).astype(np.float32)[:,:width]

            # STANDARDIZE TO -1 TO 1
            mel_spec_db = (mel_spec_db+40)/40
            img[:,:,k] += mel_spec_db

        # AVERAGE THE 4 MONTAGE DIFFERENCES
        img[:,:,k] /= 4.0

        if display:
            plt.subplot(2,2,k+1)
            plt.imshow(img[:,:,k],aspect='auto',origin='lower')
            plt.title(f'EEG {eeg_name} - Spectrogram {NAMES[k]}')

    if display:
        plt.show()
        plt.figure(figsize=(10,5))
        offset = 0
        # NOTE(review): `signals` holds 16 differenced signals (4 per chain),
        # but only the first four are drawn and they are labelled with
        # NAMES[3-k]; kept as in the original plotting code.
        for k in range(4):
            if k>0: offset -= signals[3-k].min()
            plt.plot(range(10_000),signals[k]+offset,label=NAMES[3-k])
            offset += signals[3-k].max()
        plt.legend()
        plt.title(f'EEG {eeg_name} Signals')
        plt.show()
        print(); print('#'*25); print()

    return img

def plot_spectrogram(spectrogram_path: str):
    """
    Source: https://www.kaggle.com/code/mvvppp/hms-eda-and-domain-journey
    Visualize spectrogram recordings from a parquet file.

    The columns are split by their name prefix (LL, RL, RP, LP) and each
    group is shown as a log-scaled image in a 2x2 grid.

    :param spectrogram_path: path to the spectrogram parquet.
    """
    sample_spect = pd.read_parquet(spectrogram_path)

    split_spect = {
        prefix: sample_spect.filter(regex=f'^{prefix}', axis=1)
        for prefix in ("LL", "RL", "RP", "LP")
    }

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))
    label_interval = 5  # label every 5th frequency column
    for ax, (split_name, region) in zip(axes.flatten(), split_spect.items()):
        img = ax.imshow(np.log(region).T, cmap='viridis', aspect='auto', origin='lower')
        cbar = fig.colorbar(img, ax=ax)
        cbar.set_label('Log(Value)')
        ax.set_title(split_name)
        ax.set_ylabel("Frequency (Hz)")
        ax.set_xlabel("Time")

        # Column names carry a 3-character prefix; the remainder is the label.
        # Ticks are set once, already thinned out (the original first set a
        # tick for every column and then immediately overwrote them).
        frequencies = [column_name[3:] for column_name in region.columns]
        ax.set_yticks(np.arange(0, len(frequencies), label_interval))
        ax.set_yticklabels(frequencies[::label_interval])
    plt.tight_layout()
    plt.show()

In [None]:
def eeg_from_parquet(parquet_path, display=False):
    """
    Load the raw EEG channels listed in ``FEATS`` from one parquet file.

    Only the middle 10,000 rows are kept. NaNs in a channel are replaced by
    that channel's mean; a channel that is entirely NaN becomes all zeros.

    :param parquet_path: path of the EEG parquet (a string; its file-name
        stem is shown in the plot title).
    :param display: when truthy, plot every channel stacked vertically.
    :return: array of shape (10_000, len(FEATS)).
    """
    # EXTRACT MIDDLE 10,000 ROWS
    frame = pd.read_parquet(parquet_path, columns=FEATS)
    start = (len(frame) - 10_000) // 2
    frame = frame.iloc[start:start + 10_000]

    if display:
        plt.figure(figsize=(10, 5))
    shift = 0  # running vertical offset, only used when plotting

    # CONVERT TO NUMPY
    data = np.zeros((10_000, len(FEATS)))
    for j, col in enumerate(FEATS):

        # FILL NAN
        x = frame[col].values.astype('float32')
        m = np.nanmean(x)
        if np.isnan(x).mean() < 1:
            x = np.nan_to_num(x, nan=m)
        else:
            x[:] = 0

        data[:, j] = x

        if display:
            if j != 0:
                shift += x.max()
            plt.plot(range(10_000), x - shift, label=col)
            shift -= x.min()

    if display:
        plt.legend()
        name = parquet_path.split('/')[-1].split('.')[0]
        plt.title(f'EEG {name}',size=16)
        plt.show()

    return data

In [None]:
# Test metadata; later cells read the `eeg_id` column from it.
test = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/test.csv')
print('Test shape:',test.shape)
test.head()

In [None]:
# eeg_id -> raw EEG array returned by `eeg_from_parquet`
all_eegs2 = {}
# Number of leading recordings that are also plotted while loading.
DISPLAY = 1
EEG_IDS2 = test.eeg_id.unique()
PATH2 = '/kaggle/input/hms-harmful-brain-activity-classification/test_eegs/'

# Raw channels to load; `eeg_from_parquet` reads this module-level list.
FEATS = ['Fp1','T3','C3','O1','Fp2','C4','T4','O2']
# FEATS = ['Fp1', 'F3', 'C3', 'P3', 'F7', 'T3', 'T5', 'O1', 'Fz', 'Cz', 'Pz', 'Fp2', 'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG']
print(f'There are {len(FEATS)} raw eeg features')
print( list(FEATS) )

print('Processing Test EEG parquets...'); print()
for i,eeg_id in enumerate(EEG_IDS2):
        
    # SAVE EEG TO PYTHON DICTIONARY OF NUMPY ARRAYS
    data = eeg_from_parquet(f'{PATH2}{eeg_id}.parquet', i<DISPLAY)
    all_eegs2[eeg_id] = data

In [None]:
# Seed the random number generators through Lightning (workers included).
pl.seed_everything(Config.seed, workers=True)

In [None]:
def config_to_dict(cfg):
    """Return every attribute of ``cfg`` whose name does not start with ``__`` as a dict."""
    return {
        attr: getattr(cfg, attr)
        for attr in dir(cfg)
        if not attr.startswith('__')
    }

In [None]:
# Rebuild `test` with one row per unique EEG id; keep the full CSV in `test2`.
test = pd.DataFrame(EEG_IDS2, columns=['eeg_id'])
test2 = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/test.csv')

In [None]:
# Unique spectrogram ids of the rows whose eeg_id is in EEG_IDS2.
test_spec_ids = test2[test2['eeg_id'].isin(EEG_IDS2)].spectrogram_id.unique()

In [None]:
# Look up each EEG's spectrogram id by `eeg_id` (first occurrence in the CSV).
# Assigning the array of *unique* spectrogram ids positionally only works when
# every EEG has its own spectrogram: if two EEGs share one, the lengths differ
# and the assignment fails.
test['spectrogram_id'] = test['eeg_id'].map(
    test2.drop_duplicates('eeg_id').set_index('eeg_id')['spectrogram_id']
)

In [None]:
test2

In [None]:
test

In [None]:
import pandas as pd, numpy as np, os
import matplotlib.pyplot as plt

# Training metadata; used below to derive TARGETS and the per-EEG label table.
df = pd.read_csv(f'{Config.data_root}train.csv')
print( df.shape )


In [None]:

# Collapse train.csv to one row per eeg_id with normalised target columns.
EEG_IDS = df.eeg_id.unique()

# The last six columns of train.csv are the targets
# (presumably the per-class vote counts -- confirm against the CSV).
TARGETS = df.columns[-6:]
TARS = {'Seizure':0, 'LPD':1, 'GPD':2, 'LRDA':3, 'GRDA':4, 'Other':5}
TARS_INV = {x:y for y,x in TARS.items()}

# First spectrogram id and smallest label offset per EEG.
train_df = df.groupby('eeg_id')[['spectrogram_id','spectrogram_label_offset_seconds']].agg({
    'spectrogram_id':'first',
    'spectrogram_label_offset_seconds':'min'
})
# NOTE(review): the column is named 'spectogram_id' (missing "r"), while the
# dataset class reads `row.spectrogram_id` -- confirm before using `train`
# with EEGDataset.
train_df.columns = ['spectogram_id','min']

# Largest label offset per EEG.
aux = df.groupby('eeg_id')[['spectrogram_id','spectrogram_label_offset_seconds']].agg({
    'spectrogram_label_offset_seconds':'max'
})
train_df['max'] = aux

aux = df.groupby('eeg_id')[['patient_id']].agg('first')
train_df['patient_id'] = aux

# Sum the target columns over all rows of an EEG ...
aux = df.groupby('eeg_id')[TARGETS].agg('sum')
for label in TARGETS:
    train_df[label] = aux[label].values

# ... and normalise each row so that it sums to 1.
y_data = train_df[TARGETS].values
y_data = y_data / y_data.sum(axis=1,keepdims=True)
train_df[TARGETS] = y_data

aux = df.groupby('eeg_id')[['expert_consensus']].agg('first')
train_df['target'] = aux

train = train_df.reset_index()
print('Train non-overlapp eeg_id shape:', train_df.shape )
train.head()


In [None]:
train.head()

In [None]:
# Keep the class count in sync with the label map (TARS has six entries).
Config.num_classes = len(TARS.keys())

In [None]:
train.describe()

In [None]:
from tqdm import tqdm
tqdm.pandas()

In [None]:
%%time
from glob import glob

# Load every test spectrogram parquet into a dict: spectrogram id -> 2D array.
READ_SPEC_FILES = True

paths_spectograms = glob(Config.kaggle_spec_path + "*.parquet")
print(f'There are {len(paths_spectograms)} spectrogram parquets')

if READ_SPEC_FILES:    
    all_spectrograms = {}
    for file_path in tqdm(paths_spectograms):
        aux = pd.read_parquet(file_path)
        # The file-name stem is the integer spectrogram id.
        name = int(file_path.split("/")[-1].split('.')[0])
        # Drop the first column and keep the remaining values.
        all_spectrograms[name] = aux.iloc[:,1:].values
        del aux
else:
    # NOTE(review): `Config.PRE_LOADED_SPECTOGRAMS` is not defined in the
    # Config class visible in this file; this branch is unreachable while
    # READ_SPEC_FILES is True.
    all_spectrograms = np.load(Config.PRE_LOADED_SPECTOGRAMS, allow_pickle=True).item()


In [None]:
%%time

# Convert every test EEG parquet into a spectrogram image with
# `spectrogram_from_eeg`, keyed by the integer EEG id.
paths_eegs = glob("/kaggle/input/hms-harmful-brain-activity-classification/test_eegs/" + "*.parquet")
print(f'There are {len(paths_eegs)} EEG spectrograms')
all_eegs = {}
counter = 0

for file_path in tqdm(paths_eegs):
    eeg_id = file_path.split("/")[-1].split(".")[0]
    # Only the first file is plotted (display flag is `counter < 1`).
    eeg_spectrogram = spectrogram_from_eeg(file_path, counter < 1)
    all_eegs[int(eeg_id)] = eeg_spectrogram
    counter += 1
    
    

    


In [None]:
len(all_eegs2)

In [None]:
all_eegs2[test.loc[0,'eeg_id']].shape

## Model Class

#### Another thing I added to the model is a set of parallel convolution blocks: several 1D CNNs with different kernel sizes, which capture features over different window lengths. It helped a lot in improving my score. You can experiment with different kernel sizes via the `kernels` parameter of the model class.


In [None]:
import torch
import torch.nn as nn


class ResNet_1D_Block(nn.Module):
    """Residual block for 1D feature maps.

    Main path: (BatchNorm -> ReLU -> Dropout -> Conv1d) twice, followed by a
    max-pool of size 2. Shortcut path: ``downsampling(x)``. The two paths are
    summed, so ``downsampling`` must produce the same shape as the main path.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super().__init__()
        # Both convolutions share everything except their input width.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=0.0, inplace=False)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.downsampling = downsampling

    def forward(self, x):
        # Main path
        residual = self.conv1(self.dropout(self.relu(self.bn1(x))))
        residual = self.conv2(self.dropout(self.relu(self.bn2(residual))))
        residual = self.maxpool(residual)

        # Shortcut path
        shortcut = self.downsampling(x)

        residual += shortcut
        return residual



In [None]:
import timm

class EEGMegaNet(nn.Module):
    """Two-branch network over raw EEG signals and spectrogram images.

    * 1D branch: parallel Conv1d layers with different kernel sizes, a stack
      of ``ResNet_1D_Block`` layers and a bidirectional GRU over the raw input.
    * 2D branch: a timm backbone applied to the stacked spectrograms.

    ``forward`` returns five tensors: the fused logits, the 128-d 1D
    embedding, the 128-d 2D embedding, and the logits of the 1D-only and
    2D-only heads.
    """

    def __init__(self, backbone_2d,in_channels_2d, kernels, pretrained=False, in_channels=20, fixed_kernel_size=17, num_classes=6):
        """
        :param backbone_2d: timm model name used for the spectrogram branch.
        :param in_channels_2d: accepted for interface compatibility; not used.
        :param kernels: kernel sizes of the parallel Conv1d layers.
        :param pretrained: forwarded to ``timm.create_model``.
        :param in_channels: number of raw EEG channels.
        :param fixed_kernel_size: kernel size of ``conv1`` and the residual blocks.
        :param num_classes: number of output logits of every head.
        """
        super(EEGMegaNet, self).__init__()

        self.kernels = kernels
        self.planes = 24
        self.parallel_conv = nn.ModuleList()
        self.in_channels = in_channels

        # Use the `backbone_2d` argument (the original ignored it and always
        # read Config.backbone_2d, which every caller in this file passes anyway).
        self.backbone_2d = timm.create_model(
            backbone_2d,
            pretrained=pretrained,
            drop_rate = 0.1,
            drop_path_rate = 0.1
        )

        # Backbone without its last two children, followed by global pooling.
        self.features_2d = nn.Sequential(*list(self.backbone_2d.children())[:-2] + [nn.AdaptiveAvgPool2d(1),nn.Flatten()])

        for kernel_size in self.kernels:
            sep_conv = nn.Conv1d(in_channels=in_channels, out_channels=self.planes, kernel_size=kernel_size,
                               stride=1, padding=0, bias=False,)
            self.parallel_conv.append(sep_conv)

        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv1d(in_channels=self.planes, out_channels=self.planes, kernel_size=fixed_kernel_size,
                               stride=2, padding=2, bias=False)
        self.block = self._make_resnet_layer(kernel_size=fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=4, padding=2)
        self.rnn = nn.GRU(input_size=self.in_channels, hidden_size=128, num_layers=1, bidirectional=True)

        # NOTE(review): 1280 and 736 are hard-coded feature widths; they
        # presumably match the configured backbone and the 10,000-sample input
        # used in this notebook -- confirm before changing either.
        self.fc1 = nn.Linear(in_features=1280, out_features=128)
        self.fc2 = nn.Linear(in_features=736, out_features=128)
        self.fc = nn.Linear(in_features=256, out_features=num_classes)

        self.fc1d = nn.Linear(in_features=128, out_features=num_classes)
        self.fc2d = nn.Linear(in_features=128, out_features=num_classes)

        # Not used in forward(); kept so the module's parameter set (and
        # therefore its state_dict keys) stays unchanged.
        self.rnn1 = nn.GRU(input_size=156, hidden_size=156, num_layers=1, bidirectional=True)

    def _make_resnet_layer(self, kernel_size, stride, blocks=8, padding=0):
        """Stack ``blocks`` residual blocks, each halving the sequence length."""
        layers = []
        for _ in range(blocks):
            downsampling = nn.Sequential(
                    nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
                )
            layers.append(ResNet_1D_Block(in_channels=self.planes, out_channels=self.planes, kernel_size=kernel_size,
                                       stride=stride, padding=padding, downsampling=downsampling))

        return nn.Sequential(*layers)

    def _reshape_input(self, spec):
        """
        Turn a channels-last batch into a 3-channel image batch.

        The first four slices of the last axis are stacked along dim 1, the
        result is repeated three times along the last axis and permuted to
        channels-first. For an input of shape (B, 128, 256, 8) this yields
        (B, 3, 512, 256). Slices 4..7 of the last axis are not used.
        """
        # === Get spectograms ===
        spectograms = [spec[:, :, :, i:i+1] for i in range(4)]
        spec = torch.cat(spectograms, dim=1)

        # === Repeat to three channels and move channels first ===
        spec = torch.cat([spec,spec,spec], dim=3)
        spec = spec.permute(0, 3, 1, 2)
        return spec

    def forward(self, x, spec):
        """
        :param x: raw EEG batch, (B, in_channels, T).
        :param spec: spectrogram batch, channels-last (see ``_reshape_input``).
        :return: ``(result, new_out, spec, out1d, out2d)``.
        """
        # ---- 2D branch ----
        spec = self._reshape_input(spec)
        spec = self.features_2d(spec)

        # ---- 1D CNN branch ----
        out_sep = [conv(x) for conv in self.parallel_conv]

        # Outputs of the parallel convolutions are concatenated along time.
        out = torch.cat(out_sep, dim=2)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.block(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)

        out = out.reshape(out.shape[0], -1)

        # ---- GRU branch ----
        # NOTE(review): `self.rnn` is created without batch_first=True while
        # the input here is permuted to (B, T, C); kept as-is because any
        # existing checkpoints were presumably trained with this behaviour.
        rnn_out, _ = self.rnn(x.permute(0,2, 1))
        new_rnn_h = rnn_out[:, -1, :]

        # ---- Heads ----
        new_out = torch.cat([out, new_rnn_h], dim=1)
        new_out = self.fc2(new_out)
        out1d = self.fc1d(new_out)

        spec = self.fc1(spec)
        out2d = self.fc2d(spec)

        result = torch.cat([new_out, spec], dim=1)
        result = self.fc(result)

        return result, new_out, spec, out1d, out2d

In [None]:
import gc
# Smoke test: run one forward pass on random inputs (on CPU) and print the
# shape of the fused logits.
iot = torch.randn(2, Config.num_channels, 10000)#.cuda()
spec = torch.randn(2, 128, 256, 8)#.cuda()

model = EEGMegaNet(backbone_2d=Config.backbone_2d,in_channels_2d=8,kernels=[3,5,7,9],pretrained=False, in_channels=Config.num_channels, fixed_kernel_size=5, num_classes=6)#.cuda()
output,_,_,_,_ = model(iot, spec)
print(output.shape)

# Drop the temporary input and model again.
del iot, model
gc.collect()

## Define Dataset Class & Do some Visualization

In [None]:
from scipy.signal import butter, lfilter

def quantize_data(data, classes):
    """Return the mu-law encoding of ``data`` with ``mu = classes``.

    Despite the name no binning is applied; the encoded values are returned
    as they are.
    """
    return mu_law_encoding(data, classes)

def mu_law_encoding(data, mu):
    """Compress ``data`` with sign(x) * log(1 + mu*|x|) / log(1 + mu)."""
    compressed = np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
    return np.sign(data) * compressed

def mu_law_expansion(data, mu):
    """Expand ``data`` with sign(y) * ((1 + mu)**|y| - 1) / mu."""
    magnitude = (np.exp(np.abs(data) * np.log(mu + 1)) - 1) / mu
    return np.sign(data) * magnitude

def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    """Apply a Butterworth low-pass filter along axis 0 of ``data``.

    :param data: array whose first axis is time.
    :param cutoff_freq: cut-off frequency, in the same unit as ``sampling_rate``.
    :param sampling_rate: sampling frequency of ``data``.
    :param order: filter order.
    :return: filtered array with the same shape as ``data``.
    """
    # `butter` expects the cut-off as a fraction of the Nyquist frequency.
    nyquist = 0.5 * sampling_rate
    numerator, denominator = butter(order, cutoff_freq / nyquist, btype='low', analog=False)
    return lfilter(numerator, denominator, data, axis=0)

class EEGDataset(torch.utils.data.Dataset):
    """Raw-EEG-only dataset (signals, and labels unless ``test`` is set).

    NOTE: a later cell of this file defines another class with the same
    name, which replaces this one once that cell has run.
    """

    def __init__(self, data, eegs=None, augmentations = None, test = False):
        self.data = data                    # DataFrame with an `eeg_id` column
        self.eegs = eegs                    # mapping eeg_id -> (time, 8) array
        self.augmentations = augmentations  # stored, not used in this class
        self.test = test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        data = self.eegs[row.eeg_id]

        # Column index of every raw channel inside `data`.
        column_of = {'Fp1': 0, 'T3': 1, 'C3': 2, 'O1': 3, 'Fp2': 4, 'C4': 5, 'T4': 6, 'O2': 7}
        # Electrode pairs whose difference forms each of the 8 output channels.
        pairs = (
            ('Fp1', 'T3'), ('T3', 'O1'),
            ('Fp1', 'C3'), ('C3', 'O1'),
            ('Fp2', 'C4'), ('C4', 'O2'),
            ('Fp2', 'T4'), ('T4', 'O2'),
        )

        sample = np.zeros((data.shape[0], 8))
        for channel, (first, second) in enumerate(pairs):
            sample[:, channel] = data[:, column_of[first]] - data[:, column_of[second]]

        # Per-channel standardisation, clipping and NaN clean-up.
        sample = (sample - np.mean(sample, axis=0)) / np.std(sample, axis=0)
        sample = np.nan_to_num(np.clip(sample, -1024, 1024), nan=0)

        # Low-pass filter, then mu-law encode with mu = 1.
        sample = quantize_data(butter_lowpass_filter(sample), 1)

        # (time, channels) -> (channels, time)
        samples = torch.from_numpy(sample).float().permute(1,0)

        if self.test:
            return samples

        label = torch.tensor(row[TARGETS]).float()
        return samples, label
# ================================
# ================================

In [None]:
import torch
import numpy as np
import pandas as pd
from scipy.signal import butter, lfilter

class EEGDataset(torch.utils.data.Dataset):
    """Multimodal dataset: raw EEG differences plus an 8-channel spectrogram image.

    Each item is ``(samples, spec)`` in test mode and ``(samples, spec, label)``
    otherwise, where ``samples`` is a float tensor of shape (8, time) and
    ``spec`` is a float32 array of shape (128, 256, 8).
    """

    def __init__(self, data, eegs=None, specs=None, eeg_specs=None, spec_aug = False, augmentations=None, test=False):
        """
        :param data: DataFrame with ``eeg_id`` and ``spectrogram_id`` columns
            (plus ``min``/``max`` and the target columns when ``test`` is False).
        :param eegs: mapping eeg_id -> raw EEG array (time, 8).
        :param specs: mapping spectrogram_id -> 2D spectrogram array.
        :param eeg_specs: mapping eeg_id -> EEG spectrogram image (128, 256, 4).
        :param spec_aug: apply the random horizontal flip to the spectrogram image.
        :param augmentations: stored, not used by this class.
        :param test: when True no label is returned.
        """
        self.data = data
        self.eegs = eegs
        self.specs = specs  # Spectrograms for each ID
        self.eeg_specs = eeg_specs  # EEG spectrograms for each ID
        self.augmentations = augmentations
        self.test = test
        self.spec_aug = spec_aug

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]

        # Processing EEG signal data
        eeg_data = self.eegs[row.eeg_id]
        sample = np.zeros((eeg_data.shape[0], 8))  # Assuming eeg_data.shape[1] == 8 for 8 channels
        FEAT2IDX = {'Fp1': 0, 'T3': 1, 'C3': 2, 'O1': 3, 'Fp2': 4, 'C4': 5, 'T4': 6, 'O2': 7}

        # Compute differences
        for i, (start, end) in enumerate([('Fp1', 'T3'), ('T3', 'O1'), ('Fp1', 'C3'), ('C3', 'O1'),
                                          ('Fp2', 'C4'), ('C4', 'O2'), ('Fp2', 'T4'), ('T4', 'O2')]):
            sample[:, i] = eeg_data[:, FEAT2IDX[start]] - eeg_data[:, FEAT2IDX[end]]

        sample = self.process_sample(sample)

        # (time, channels) -> (channels, time)
        samples = torch.from_numpy(sample).float()
        samples = samples.permute(1, 0)

        # Processing spectrogram data
        spec = self.__spec_data_generation(row)
        if self.spec_aug:
            spec = self.__transform(spec)

        if not self.test:
            label = row[TARGETS]  # Assuming 'TARGETS' is defined somewhere as the label column name
            label = torch.tensor(label).float()
            return samples, spec, label
        else:
            return samples, spec

    def process_sample(self, sample):
        """Standardise per channel, clip, clean NaNs, low-pass filter and mu-law encode."""
        sample = (sample - np.mean(sample, axis=0)) / np.std(sample, axis=0)
        sample = np.clip(sample, -1024, 1024)
        sample = np.nan_to_num(sample, nan=0)
        sample = butter_lowpass_filter(sample)
        sample = quantize_data(sample, 1)
        return sample

    def __spec_data_generation(self, row):
        """
        Build the (128, 256, 8) spectrogram image for a single row.

        Channels 0-3 hold the four 100-column regions of the spectrogram
        (log-transformed, standardised per region, centred with zero padding);
        channels 4-7 hold the precomputed EEG spectrogram of the same EEG.
        """
        X = np.zeros((128, 256, 8), dtype='float32')

        if not self.test:
            # First time row of the 300-row window, derived from the label offsets.
            r = int((row['min'] + row['max']) // 4)
        else:
            r = 0  # Adjust as necessary for test mode

        for region in range(4):
            img = self.specs[row.spectrogram_id][r:r+300, region*100:(region+1)*100].T

            # Log transform spectogram
            img = np.clip(img, np.exp(-4), np.exp(8))
            img = np.log(img)

            # Standarize per image
            ep = 1e-6
            mu = np.nanmean(img.flatten())
            std = np.nanstd(img.flatten())
            img = (img-mu)/(std+ep)
            img = np.nan_to_num(img, nan=0.0)
            # Crop 300 -> 256 time steps and pad 100 -> 128 rows.
            X[14:-14, :, region] = img[:, 22:-22] / 2.0

        # Precomputed EEG spectrogram fills the remaining four channels.
        img = self.eeg_specs[row.eeg_id]
        X[:, :, 4:] = img
        return X

    def __transform(self, img):
        """Randomly flip the image horizontally (p=0.5)."""
        # albumentations is imported as `alb` in this file; the original
        # referenced an undefined name `A`.
        pipeline = alb.Compose([
            alb.HorizontalFlip(p=0.5),
        ])
        return pipeline(image=img)['image']



In [None]:
def get_test_dls(df_test):
    """Build the test dataset and its (unshuffled) DataLoader.

    The raw EEGs, spectrograms and EEG spectrograms are taken from the
    module-level dictionaries ``all_eegs2``, ``all_spectrograms`` and
    ``all_eegs``.

    :return: ``(dl_test, ds_test)``.
    """
    sources = dict(eegs=all_eegs2, specs=all_spectrograms, eeg_specs=all_eegs)
    ds_test = EEGDataset(df_test, augmentations=None, test=True, **sources)
    dl_test = DataLoader(
        ds_test,
        batch_size=Config.batch_size,
        shuffle=False,
        num_workers=2,
    )
    return dl_test, ds_test

In [None]:


def show_batch(img_ds, num_items, num_rows, num_cols, EEG_IDS, predict_arr=None):
    """Plot the EEG channels of dataset item 0 in the first two grid cells.

    :param img_ds: dataset whose items are ``(signal, spec)`` pairs.
    :param num_items: not used.
    :param num_rows: rows of the subplot grid.
    :param num_cols: columns of the subplot grid.
    :param EEG_IDS: sequence of ids; ``EEG_IDS[0]`` is shown in the titles.
    :param predict_arr: not used.
    """
    fig = plt.figure(figsize=(12, 6))
    # Item 0 is drawn twice, once per panel.
    for panel, item_idx in enumerate((0, 0), start=1):
        signal, _spec = img_ds[item_idx]
        ax = fig.add_subplot(num_rows, num_cols, panel, xticks=[], yticks=[])
        if isinstance(signal, torch.Tensor):
            # (channels, time) tensor -> (time, channels) array
            signal = signal.detach().numpy().transpose(1,0)

        # Stack the channels vertically so that they do not overlap.
        shift = 0
        for ch in range(signal.shape[-1]):
            trace = signal[:, ch]
            if ch != 0:
                shift -= trace.min()
            ax.plot(trace + shift, label=f'feature {ch+1}')
            shift += trace.max() + 1  # Adding 1 for visual separation

        ax.legend()
        ax.set_title(f'EEG_Id = {EEG_IDS[item_idx]}', size=14)

    plt.tight_layout()
    plt.show()


In [None]:
dummy_test = test.copy()
dummy_test

In [None]:
dl_test, ds_test = get_test_dls(dummy_test)
# The dataset is built from the test EEGs, so the plot titles must use the
# test id list (EEG_IDS2); EEG_IDS holds the ids of train.csv.
show_batch(ds_test, 8, 2, 4, EEG_IDS2)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau, OneCycleLR

def get_optimizer(lr, params):
    """Return the Lightning optimizer configuration.

    Adam (weight decay from ``Config.weight_decay``) over the parameters that
    require gradients, with a cosine-annealing warm-restart schedule stepped
    once per epoch.

    :param lr: initial learning rate.
    :param params: iterable of model parameters.
    :return: dict with ``"optimizer"`` and ``"lr_scheduler"`` entries.
    """
    trainable = filter(lambda p: p.requires_grad, params)
    model_optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=Config.weight_decay)

    lr_scheduler = CosineAnnealingWarmRestarts(
        model_optimizer,
        T_0=Config.epochs,
        T_mult=1,
        eta_min=1e-5,
        last_epoch=-1,
    )

    scheduler_cfg = {
        "scheduler": lr_scheduler,
        "interval": "epoch",
        "monitor": "val_loss",
        "frequency": 1,
    }
    return {"optimizer": model_optimizer, "lr_scheduler": scheduler_cfg}

In [None]:
import torch.nn as nn
from torch.nn.functional import cross_entropy
import torchmetrics


In [None]:
import sklearn.metrics
import sys

In [None]:
class KLDivLossWithLogits(nn.KLDivLoss):
    """KL-divergence loss (``batchmean`` reduction) that takes raw logits.

    ``forward(y, t)`` applies ``log_softmax`` over dim 1 to the logits ``y``
    and compares the result with the target probabilities ``t``.
    """

    def __init__(self):
        super().__init__(reduction="batchmean")

    def forward(self, y, t):
        log_probs = nn.functional.log_softmax(y, dim=1)
        return super().forward(log_probs, t)


## Define Pytorch Lightning Module

In [None]:
class EEGModel(pl.LightningModule):
    """Lightning wrapper around ``EEGMegaNet``.

    ``forward`` returns the backbone's full 5-tuple
    ``(fused_logits, emb_1d, emb_2d, logits_1d, logits_2d)``.
    """

    def __init__(self, num_classes = Config.num_classes, pretrained = Config.pretrained):
        super().__init__()
        self.num_classes = num_classes
        # NOTE(review): the backbone is built from Config (and a literal 6
        # classes); the `pretrained` / `num_classes` arguments do not reach it.
        # self.backbone = EEGNet(kernels=[3,5,7,9], in_channels=Config.num_channels, fixed_kernel_size=5, num_classes=Config.num_classes)
        self.backbone = EEGMegaNet(backbone_2d=Config.backbone_2d,
                                   in_channels_2d=8,
                                   kernels=[3,5,7,9],pretrained=Config.pretrained,
                                   in_channels=Config.num_channels,
                                   fixed_kernel_size=5, num_classes=6)

        self.loss_function = KLDivLossWithLogits() #nn.KLDivLoss() #nn.BCEWithLogitsLoss()
        self.validation_step_outputs = []
        self.lin = nn.Softmax(dim=1)
        self.best_score = 1000.0
        # Fold label used in the validation log lines; callers may overwrite it.
        # Defined here so that logging cannot fail with an AttributeError.
        self.fold = None

    def forward(self,eeg, spec):
        logits = self.backbone(eeg, spec)
        # logits = self.lin(logits)
        return logits

    def configure_optimizers(self):
        return get_optimizer(lr=Config.LR, params=self.parameters())

    def training_step(self, batch, batch_idx):
        eeg, spec, target = batch

        # forward() returns a tuple; the loss is computed on the fused logits
        # (first element). The original passed the whole tuple to the loss.
        # NOTE(review): the auxiliary heads are not supervised here -- confirm
        # against the training notebook if they should contribute to the loss.
        y_pred = self(eeg, spec)[0]
        loss = self.loss_function(y_pred,target)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        eeg, spec, target = batch
        # Fused logits only, see training_step.
        y_pred = self(eeg, spec)[0]
        val_loss = self.loss_function(y_pred, target)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        # Collected here and consumed in on_validation_epoch_end.
        self.validation_step_outputs.append({"val_loss": val_loss, "logits": y_pred, "targets": target})

        return {"val_loss": val_loss, "logits": y_pred, "targets": target}

    # NOTE(review): `_train_dataloader` / `_validation_dataloader` are never
    # assigned in this class; they must be set by the caller before use.
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        # print(len(outputs))
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        output_val = nn.Softmax(dim=1)(torch.cat([x['logits'] for x in outputs],dim=0)).cpu().detach().numpy()
        target_val = torch.cat([x['targets'] for x in outputs],dim=0).cpu().detach().numpy()
        # Reset the buffer for the next validation epoch.
        self.validation_step_outputs = []

        val_df = pd.DataFrame(target_val, columns = list(TARGETS))
        pred_df = pd.DataFrame(output_val, columns = list(TARGETS))

        val_df['id'] = [f'id_{i}' for i in range(len(val_df))]
        pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]

        # The mean validation loss doubles as the score.
        avg_score = avg_loss
        # avg_score = score(val_df, pred_df, row_id_column_name = 'id')

        if avg_score < self.best_score:
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation loss {avg_loss}')
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation KDL score {avg_score}')
            self.best_score = avg_score
            # val_df.to_csv(f'{Config.output_dir}/val_df_f{self.fold}.csv',index=False)
            # pred_df.to_csv(f'{Config.output_dir}/pred_df_f{self.fold}.csv',index=False)

        return {'val_loss': avg_loss,'val_cmap':avg_score}
    


In [None]:
def predict(data_loader, model):
    """Run ``model`` over ``data_loader`` on the GPU and return class probabilities.

    For every batch the fused logits and the two per-branch logits are
    blended with weights 0.5 / 0.25 / 0.25 before the softmax.

    :param data_loader: yields ``(eeg, spec)`` batches.
    :param model: callable returning ``(fused, _, _, out1d, out2d)``.
    :return: array with one row of probabilities per sample.
    """
    model.to('cuda')
    model.eval()

    rows = []
    for x, spec in tqdm(data_loader):
        with torch.no_grad():
            fused, _, _, out1d, out2d = model(x.cuda(), spec.cuda())
            blended = fused*0.5 + out1d*0.25 + out2d*0.25
            probs = nn.Softmax(dim=1)(blended)
        rows.extend(probs.detach().cpu().numpy())

    return np.vstack(rows)

def predict_nobatch(ds_test, model):
    """Predict sample by sample (batch size 1) on the GPU.

    Mirrors ``predict``: the fused and per-branch logits are blended with
    weights 0.5 / 0.25 / 0.25 before the softmax.

    :param ds_test: dataset in test mode whose items are ``(eeg, spec)`` pairs.
    :param model: callable returning ``(fused, _, _, out1d, out2d)``.
    :return: list with one (1, num_classes) probability array per sample.
    """
    model.to('cuda')
    model.eval()
    predictions = []
    for en in tqdm(range(len(ds_test))):
        # Each item is an (eeg, spec) pair; the spectrogram may still be a
        # numpy array because no DataLoader collation happens here.
        eeg, spec = ds_test[en]
        eeg = eeg.unsqueeze(0).cuda()
        spec = torch.as_tensor(spec).unsqueeze(0).cuda()

        with torch.no_grad():
            outputs, _, _, out1d, out2d = model(eeg, spec)
            outputs = outputs*0.5 + out1d*0.25 + out2d*0.25
            outputs = nn.Softmax(dim=1)(outputs)
            outputs = outputs.detach().cpu().numpy()

        predictions.append(outputs)

    return predictions

In [None]:
from pytorch_lightning.loggers import WandbLogger
import gc
torch.set_float32_matmul_precision('high')
def run_inference(fold_id, Config):
    """Load the checkpoint of one fold and predict the test set.

    :param fold_id: fold number; selects
        ``{Config.model_dir}/eegnet_best_loss_fold{fold_id}.ckpt``.
    :param Config: configuration namespace providing ``model_dir``.
    :return: array of class probabilities, one row per test EEG.
    """
    # This function only runs inference (the original message said "training").
    print(f"Running inference for fold {fold_id}...")
    df_test = test.copy()
    dl_test, ds_test = get_test_dls(df_test)

    ckpt_path = f'{Config.model_dir}/eegnet_best_loss_fold{fold_id}.ckpt'
    print(f"Running inference model '{ckpt_path}'..")

    model = EEGModel.load_from_checkpoint(ckpt_path,map_location='cuda:0',
                                          train_dataloader=None,validation_dataloader=None,config=Config)

    preds = predict(dl_test, model)
    print(preds.shape)
    gc.collect()
    # torch.cuda.empty_cache()
    return preds
    

In [None]:
# list(train[train['fold']==0].index)
test.head(3)
# Submission frame: one row per test EEG, target columns initialised to 0.
sub_df = test[['eeg_id']].copy()
sub_df[TARGETS] = 0.0

In [None]:
sub_df

In [None]:
# Collect the predictions of every fold checkpoint listed in Config.trn_folds.
test_preds = []
for en,f in enumerate(Config.trn_folds):
    preds = run_inference(f, Config)
    test_preds.append(preds)

In [None]:
test_preds = np.mean(test_preds, 0)
test_preds.shape

In [None]:
# Write the fold-averaged probabilities into the target columns and save.
sub_df[TARGETS] = test_preds
sub_df.to_csv('submission.csv',index=False)

In [None]:
sub_df

In [None]:
print('Sub row 0 sums to:',sub_df.iloc[0,-6:].sum())


### Here we go, submission looks good :D 