<a href="https://colab.research.google.com/github/shuaicongbaobao/Colab-140/blob/main/YMW_CNN_Model_Brain_Seizure.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import gc
import os
import random
import warnings
import numpy as np
import pandas as pd
from IPython.display import display

!pip install timm
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

from google.colab import drive
drive.mount('/content/drive')

from scipy import signal


warnings.filterwarnings('ignore', category=Warning)
gc.collect()

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.7/2.2 MB[0m [31m19.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-non

16

In [3]:
# Configuration class containing hyperparameters and settings
class Config:
    seed = 42
    image_transform = transforms.Resize((512,512))
    batch_size = 16
    num_epochs = 9
    num_folds = 5

# Set the seed for reproducibility across multiple libraries
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed(Config.seed)

# Define the 'Kullback Leibler Divergence' loss function
def KL_loss(p,q):
    epsilon=10**(-15)
    p=torch.clip(p,epsilon,1-epsilon)
    q = nn.functional.log_softmax(q,dim=1)
    return torch.mean(torch.sum(p*(torch.log(p)-q),dim=1))

# Reclaim memory no longer in use.
gc.collect()

0

In [4]:
# Load training data
train_df = pd.read_csv("/content/drive/MyDrive/hms-harmful-brain-activity-classification/train.csv")

# Define labels for classification
labels = ['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']

# Initialize an empty DataFrame for storing features
train_feats = pd.DataFrame()

# Aggregate votes for each label and merge into train_feats DataFrame
for label in labels:
    # Group by 'spectrogram_id' and sum the votes for the current label
    group = train_df[f'{label}_vote'].groupby(train_df['spectrogram_id']).sum()

    # Create a DataFrame from the grouped data
    label_vote_sum = pd.DataFrame({'spectrogram_id': group.index, f'{label}_vote_sum': group.values})

    # Initialize train_feats with the first label or merge subsequent labels
    if label == 'seizure':
        train_feats = label_vote_sum
    else:
        train_feats = train_feats.merge(label_vote_sum, on='spectrogram_id', how='left')

# Add a column to sum all votes
train_feats['total_vote'] = 0
for label in labels:
    train_feats['total_vote'] += train_feats[f'{label}_vote_sum']

# Calculate and store the normalized vote for each label
for label in labels:
    train_feats[f'{label}_vote'] = train_feats[f'{label}_vote_sum'] / train_feats['total_vote']

# Select relevant columns for the training features
choose_cols = ['spectrogram_id']
for label in labels:
    choose_cols += [f'{label}_vote']
train_feats = train_feats[choose_cols]

# Add a column with the path to the spectrogram files
train_feats['path'] = train_feats['spectrogram_id'].apply(lambda x: "/content/drive/MyDrive/hms-harmful-brain-activity-classification/train_spectrograms/" + str(x) + ".parquet")

# Reclaim memory no longer in use.
gc.collect()

0

In [5]:
def get_batch(paths, batch_size=Config.batch_size):
    # Set a small epsilon to avoid division by zero
    eps = 1e-6

    # Initialize a list to store batch data
    batch_data = []

    # Iterate over each path in the provided paths
    for path in paths:
        # Read data from parquet file
        data = pd.read_parquet(path[0])

        # Fill missing values, remove time column, and transpose
        data = data.fillna(-1).values[:, 1:].T

        # Clip values and apply logarithmic transformation
        data = np.clip(data, np.exp(-6), np.exp(10))
        data = np.log(data)

        # Normalize the data
        data_mean = data.mean(axis=(0, 1))
        data_std = data.std(axis=(0, 1))
        data = (data - data_mean) / (data_std + eps)

        # Convert data to a PyTorch tensor and apply transformations
        data_tensor = torch.unsqueeze(torch.Tensor(data), dim=0)
        data = Config.image_transform(data_tensor)

        # Append the processed data to the batch_data list
        batch_data.append(data)

    # Stack all the batch data into a single tensor
    batch_data = torch.stack(batch_data)

    # Return the batch data
    return batch_data

In [10]:
train_df.head()

Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,1628180742,0,0.0,353733,0,0.0,127492639,42516,Seizure,3,0,0,0,0,0
1,1628180742,1,6.0,353733,1,6.0,3887563113,42516,Seizure,3,0,0,0,0,0
2,1628180742,2,8.0,353733,2,8.0,1142670488,42516,Seizure,3,0,0,0,0,0
3,1628180742,3,18.0,353733,3,18.0,2718991173,42516,Seizure,3,0,0,0,0,0
4,1628180742,4,24.0,353733,4,24.0,3080632009,42516,Seizure,3,0,0,0,0,0


In [6]:
from tqdm import tqdm
# Determine device availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Assuming train_feats is defined and contains the training features and labels
total_idx = np.arange(len(train_feats))
np.random.shuffle(total_idx)

gc.collect()
eeg_dictionary={}
eeg_val_dictionary={}
# Cross-validation loop
for fold in range(1):
    # Split data into train and test sets for this fold
    total_idx = total_idx[:600]
    test_idx = total_idx[fold * len(total_idx) // Config.num_folds:(fold + 1) * len(total_idx) // Config.num_folds]
    train_idx = np.array([idx for idx in total_idx if idx not in test_idx])

    # Initialize EfficientNet-B0 model with pretrained weights
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=6, in_chans=1)
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=Config.num_epochs)

    best_test_loss = float('inf')
    train_losses = []
    test_losses = []

    print(f"Starting training for fold {fold + 1}")

    # Training loop
    for epoch in range(Config.num_epochs):
        model.train()
        train_loss = []
        random_num = np.arange(len(train_idx))
        np.random.shuffle(random_num)
        train_idx = train_idx[random_num]


        # Iterate over batches in the training set
        for idx in tqdm(range(0, len(train_idx), Config.batch_size)):
            optimizer.zero_grad()
            train_idx1 = train_idx[idx:idx + Config.batch_size]
            train_X1_path = train_feats[['path']].iloc[train_idx1].values
            print(f"loading {idx} batch's data total {len(train_idx) // Config.batch_size} batches")
            train_X1 = get_batch(train_X1_path, batch_size=Config.batch_size)
            train_y1 = train_feats[['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']].iloc[train_idx1].values
            train_y1 = torch.Tensor(train_y1)
            train_pred = model(train_X1.to(device))
            spectrogram_id = train_feats['spectrogram_id'].iloc[train_idx1].values
            for i, spectrogram_id1 in enumerate(spectrogram_id):
                eeg_dictionary[spectrogram_id1] = train_pred[i,:].detach().cpu().numpy()

            loss = KL_loss(train_y1.to(device), train_pred)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        epoch_train_loss = np.mean(train_loss)
        train_losses.append(epoch_train_loss)
        print(f"Epoch {epoch + 1}: Train Loss = {epoch_train_loss:.2f}")

        scheduler.step()

        # Evaluation loop
        model.eval()
        test_loss = []
        with torch.no_grad():
            for idx in range(0, len(test_idx), Config.batch_size):
                test_idx1 = test_idx[idx:idx + Config.batch_size]
                test_X1_path = train_feats[['path']].iloc[test_idx1].values
                test_X1 = get_batch(test_X1_path, batch_size=Config.batch_size)
                test_y1 = train_feats[['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']].iloc[test_idx1].values
                test_y1 = torch.Tensor(test_y1)

                test_pred = model(test_X1.to(device))
                spectrogram_id = train_feats['spectrogram_id'].iloc[test_idx1].values
                for i, spectrogram_id1 in enumerate(spectrogram_id):
                    eeg_val_dictionary[spectrogram_id1] = test_pred[i,:].detach().cpu().numpy()
                loss = KL_loss(test_y1.to(device), test_pred)
                test_loss.append(loss.item())

        epoch_test_loss = np.mean(test_loss)
        test_losses.append(epoch_test_loss)
        print(f"Epoch {epoch + 1}: Test Loss = {epoch_test_loss:.2f}")


        # Save the model if it has the best test loss so far
        if epoch_test_loss < best_test_loss:
            best_test_loss = epoch_test_loss
            torch.save(model.state_dict(), f"efficientnet_b0_fold{fold}.pth")

        gc.collect()

    print(f"Fold {fold + 1} Best Test Loss: {best_test_loss:.2f}")



Using device: cuda


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Starting training for fold 1


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [01:49<52:48, 109.25s/it]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [02:27<31:26, 67.37s/it] 

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [03:04<24:05, 53.56s/it]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [03:42<20:30, 47.34s/it]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [04:19<18:13, 43.76s/it]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [05:24<20:24, 51.00s/it]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [06:00<17:40, 46.12s/it]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [06:37<15:51, 43.24s/it]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [07:18<14:51, 42.45s/it]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [07:58<13:53, 41.69s/it]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [08:36<12:50, 40.55s/it]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [09:14<11:54, 39.69s/it]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [09:49<10:51, 38.33s/it]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [10:24<09:58, 37.41s/it]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [11:03<09:26, 37.76s/it]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [11:40<08:45, 37.55s/it]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [12:17<08:05, 37.36s/it]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [12:52<07:21, 36.79s/it]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [13:33<06:57, 37.92s/it]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [14:27<07:08, 42.81s/it]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [15:07<06:18, 42.02s/it]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [15:45<05:24, 40.61s/it]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [16:21<04:36, 39.48s/it]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [17:00<03:55, 39.19s/it]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [17:36<03:11, 38.34s/it]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [18:13<02:31, 37.99s/it]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [18:49<01:51, 37.25s/it]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [19:25<01:13, 36.82s/it]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [20:01<00:36, 36.63s/it]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [20:35<00:00, 41.19s/it]


Epoch 1: Train Loss = 2.14
Epoch 1: Test Loss = 1.35


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.18it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.18it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:23,  1.16it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.17it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.19it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:19,  1.20it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:05<00:19,  1.19it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.19it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:17,  1.20it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:16,  1.18it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:15,  1.19it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.19it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:10<00:14,  1.20it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:11<00:13,  1.20it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.20it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.19it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:10,  1.18it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.17it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.17it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:16<00:08,  1.16it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.17it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.16it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:05,  1.17it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.17it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.17it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:21<00:03,  1.18it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:22<00:02,  1.18it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:23<00:01,  1.18it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.17it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.18it/s]


Epoch 2: Train Loss = 0.95
Epoch 2: Test Loss = 1.48


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.17it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.18it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:22,  1.18it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.18it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.18it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.18it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:05<00:19,  1.18it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.18it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:18,  1.16it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:17,  1.16it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.15it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.14it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.16it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:12<00:13,  1.15it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.16it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:12,  1.16it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.17it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.16it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.17it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.18it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.18it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.19it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:05,  1.18it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.17it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.17it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.15it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.15it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:23<00:01,  1.17it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.18it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 3: Train Loss = 0.47
Epoch 3: Test Loss = 1.09


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.19it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.20it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:23,  1.16it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.14it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.14it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.15it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:06<00:19,  1.16it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.17it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:17,  1.18it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:16,  1.18it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.19it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.18it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.19it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:11<00:13,  1.19it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.19it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.18it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.17it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.16it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.14it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.14it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.16it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.16it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:06,  1.16it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.17it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.17it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.18it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.17it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:23<00:01,  1.18it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.17it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 4: Train Loss = 0.34
Epoch 4: Test Loss = 1.16


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.19it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.19it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:22,  1.19it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:21,  1.20it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.18it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.19it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:05<00:19,  1.18it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.19it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:17,  1.18it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:16,  1.18it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.14it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.13it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.14it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:12<00:14,  1.14it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:13,  1.15it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.17it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.18it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.17it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.18it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.18it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.20it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.19it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:05,  1.17it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.17it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.17it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.16it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.16it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:23<00:01,  1.17it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.18it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 5: Train Loss = 0.21
Epoch 5: Test Loss = 0.99


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.20it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.19it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:23,  1.17it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.16it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:22,  1.13it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.15it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:06<00:19,  1.15it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.17it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:17,  1.17it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:16,  1.18it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.17it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.18it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.19it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:11<00:13,  1.18it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.18it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.17it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.18it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.16it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.15it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.14it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:18<00:07,  1.15it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.16it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:06,  1.15it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.16it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.16it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.17it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.16it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:24<00:01,  1.16it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.16it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 6: Train Loss = 0.16
Epoch 6: Test Loss = 0.93


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.18it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.17it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:22,  1.17it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.18it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.17it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.16it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:05<00:19,  1.17it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.16it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:18,  1.16it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:17,  1.16it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.15it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.13it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.15it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:12<00:13,  1.17it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.17it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.18it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.18it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.16it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.17it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.18it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.18it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.16it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:06,  1.16it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.16it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.16it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.16it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.15it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:24<00:01,  1.17it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.18it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 7: Train Loss = 0.12
Epoch 7: Test Loss = 0.96


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.16it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:24,  1.15it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:23,  1.15it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.15it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.14it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.16it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:06<00:19,  1.16it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.17it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:17,  1.18it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:16,  1.18it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.18it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.18it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.18it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:11<00:13,  1.19it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.19it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.18it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.18it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.17it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.16it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.15it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:17<00:07,  1.16it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.16it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:06,  1.16it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.16it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.16it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.16it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.18it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:23<00:01,  1.18it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.17it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.17it/s]


Epoch 8: Train Loss = 0.09
Epoch 8: Test Loss = 0.96


  0%|          | 0/30 [00:00<?, ?it/s]

loading 0 batch's data total 30 batches


  3%|▎         | 1/30 [00:00<00:24,  1.19it/s]

loading 16 batch's data total 30 batches


  7%|▋         | 2/30 [00:01<00:23,  1.17it/s]

loading 32 batch's data total 30 batches


 10%|█         | 3/30 [00:02<00:22,  1.18it/s]

loading 48 batch's data total 30 batches


 13%|█▎        | 4/30 [00:03<00:22,  1.18it/s]

loading 64 batch's data total 30 batches


 17%|█▋        | 5/30 [00:04<00:21,  1.17it/s]

loading 80 batch's data total 30 batches


 20%|██        | 6/30 [00:05<00:20,  1.17it/s]

loading 96 batch's data total 30 batches


 23%|██▎       | 7/30 [00:05<00:19,  1.17it/s]

loading 112 batch's data total 30 batches


 27%|██▋       | 8/30 [00:06<00:18,  1.17it/s]

loading 128 batch's data total 30 batches


 30%|███       | 9/30 [00:07<00:18,  1.16it/s]

loading 144 batch's data total 30 batches


 33%|███▎      | 10/30 [00:08<00:17,  1.15it/s]

loading 160 batch's data total 30 batches


 37%|███▋      | 11/30 [00:09<00:16,  1.13it/s]

loading 176 batch's data total 30 batches


 40%|████      | 12/30 [00:10<00:15,  1.14it/s]

loading 192 batch's data total 30 batches


 43%|████▎     | 13/30 [00:11<00:14,  1.15it/s]

loading 208 batch's data total 30 batches


 47%|████▋     | 14/30 [00:12<00:13,  1.15it/s]

loading 224 batch's data total 30 batches


 50%|█████     | 15/30 [00:12<00:12,  1.16it/s]

loading 240 batch's data total 30 batches


 53%|█████▎    | 16/30 [00:13<00:11,  1.17it/s]

loading 256 batch's data total 30 batches


 57%|█████▋    | 17/30 [00:14<00:11,  1.17it/s]

loading 272 batch's data total 30 batches


 60%|██████    | 18/30 [00:15<00:10,  1.16it/s]

loading 288 batch's data total 30 batches


 63%|██████▎   | 19/30 [00:16<00:09,  1.17it/s]

loading 304 batch's data total 30 batches


 67%|██████▋   | 20/30 [00:17<00:08,  1.17it/s]

loading 320 batch's data total 30 batches


 70%|███████   | 21/30 [00:18<00:07,  1.17it/s]

loading 336 batch's data total 30 batches


 73%|███████▎  | 22/30 [00:18<00:06,  1.17it/s]

loading 352 batch's data total 30 batches


 77%|███████▋  | 23/30 [00:19<00:06,  1.16it/s]

loading 368 batch's data total 30 batches


 80%|████████  | 24/30 [00:20<00:05,  1.16it/s]

loading 384 batch's data total 30 batches


 83%|████████▎ | 25/30 [00:21<00:04,  1.16it/s]

loading 400 batch's data total 30 batches


 87%|████████▋ | 26/30 [00:22<00:03,  1.16it/s]

loading 416 batch's data total 30 batches


 90%|█████████ | 27/30 [00:23<00:02,  1.15it/s]

loading 432 batch's data total 30 batches


 93%|█████████▎| 28/30 [00:24<00:01,  1.16it/s]

loading 448 batch's data total 30 batches


 97%|█████████▋| 29/30 [00:24<00:00,  1.17it/s]

loading 464 batch's data total 30 batches


100%|██████████| 30/30 [00:25<00:00,  1.16it/s]


Epoch 9: Train Loss = 0.08
Epoch 9: Test Loss = 0.94
Fold 1 Best Test Loss: 0.93


In [29]:
test_pred


tensor([[-4.4880, -3.4995, -0.9396, -0.3353, -0.4375,  5.3139],
        [-0.3177, -1.7164, -1.1957,  0.4834, -2.7200,  3.2143],
        [-3.6905, -3.1251,  0.8802,  0.1769, -1.5775,  3.9135],
        [-1.6302, -0.6982, -0.1348,  0.8040, -1.0798,  2.0660],
        [ 0.3398, -1.9518,  2.8254, -0.8887, -2.0148,  0.8526],
        [-2.3926, -3.0899,  3.8636, -1.9264,  1.8529,  0.3654],
        [-1.5579, -2.6839, -0.5173,  0.2610, -1.3064,  5.4070],
        [ 1.1406, -0.6360, -1.6027, -0.2346,  0.7522,  4.7232]],
       device='cuda:0')

In [7]:
import pandas as pd
pd.DataFrame.from_dict({k:{"y_0": v[0], "y_1": v[1], "y_2": v[2], "y_3": v[3], "y_4": v[4], "y_5": v[5]} for k, v in eeg_dictionary.items()},orient="index").to_csv("cnn_train.csv")
pd.DataFrame.from_dict({k:{"y_0": v[0], "y_1": v[1], "y_2": v[2], "y_3": v[3], "y_4": v[4], "y_5": v[5]} for k, v in eeg_val_dictionary.items()},orient="index").to_csv("cnn_validation.csv")