In [9]:
%load_ext autoreload
%autoreload 2
from convnet import ConvDipNet
from timeDistributed import TimeDistributed
from torchinfo import summary
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch
import sys; sys.path.insert(0, '../')
from esinet.forward import create_forward_model, get_info
from esinet import Simulation
from copy import deepcopy
from CNN_LSTM.util import *
from dipoleDataset import DipoleDataset
import os
import mne
from esinet.evaluate import eval_auc, eval_nmse, eval_mse, eval_mean_localization_error


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# define hyperparameters
in_channels = 1
im_shape = (9,9)
n_filters = 8
kernel_size = (3,3)

# create single input ConvDipNet 
convnet: nn.Module  = ConvDipNet(in_channels, im_shape, n_filters, kernel_size)

# create TimeDistributed ConvDipNet to process all samples of timeseries at onceb
time_distributed_convnet: nn.Module = TimeDistributed(convnet, batch_first=True) # change batch_first to False for now for evaluation, will change back later

# print model summary
summary(time_distributed_convnet, input_size=(32, 100, 1, im_shape[0], im_shape[1])) # (batch_size, n_timesteps, in_channels, height, width)

Layer (type:depth-idx)                   Output Shape              Param #
TimeDistributed                          [32, 100, 5124]           --
├─ConvDipNet: 1-1                        [3200, 5124]              --
│    └─Conv2d: 2-1                       [3200, 8, 9, 9]           80
│    └─BatchNorm2d: 2-2                  [3200, 8, 9, 9]           16
│    └─Linear: 2-3                       [3200, 512]               332,288
│    └─BatchNorm1d: 2-4                  [3200, 512]               1,024
│    └─Linear: 2-5                       [3200, 5124]              2,628,612
Total params: 2,962,020
Trainable params: 2,962,020
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 9.50
Input size (MB): 1.04
Forward/backward pass size (MB): 190.57
Params size (MB): 11.85
Estimated Total Size (MB): 203.45

In [3]:
model_weight_path = "/mnt/data/convdip/model/convdip.pt"
time_distributed_convnet.load_state_dict(torch.load(model_weight_path, weights_only=True))
time_distributed_convnet.eval()

TimeDistributed(
  (module): ConvDipNet(
    (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (hidden_layer): Linear(in_features=648, out_features=512, bias=True)
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (output_layer): Linear(in_features=512, out_features=5124, bias=True)
  )
)

In [11]:
data_dir = "/mnt/data/convdip/training_data/"
eeg_data_dir = os.path.join(data_dir, "eeg_data")
interp_data_dir = os.path.join(data_dir, "interp_data")
source_data_dir = os.path.join(data_dir, "source_data")
info_path = os.path.join(data_dir, "info.fif")
dataset = DipoleDataset(eeg_data_dir, interp_data_dir, source_data_dir, info_path, im_shape=im_shape)
test_size = 0.15
val_size = 0.15

test_amount, val_amount = int(dataset.__len__() * test_size), int(dataset.__len__() * val_size)

# this function will automatically randomly split your dataset but you could also implement the split yourself
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [
            (dataset.__len__() - (test_amount + val_amount)), 
            test_amount, 
            val_amount
])

In [12]:
B = 256  # batch size
train_dataloader = torch.utils.data.DataLoader(
            train_set,
            batch_size=B,
            shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
            val_set,
            batch_size=B,
            shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
            test_set,
            batch_size=B,
            shuffle=True,
)

In [None]:
sample, target = train_dataloader.
print(sample.shape)
print(target.shape)

torch.Size([1, 9, 9])
torch.Size([5124])


In [8]:
with torch.no_grad():
    all_auc = []
    all_mse = []
    all_nmse = []
    all_mle = []
    for batch, target in tqdm(val_dataloader, position=0, desc="batch"):
        batch = batch.cuda()
        output = time_distributed_convnet(batch)
        output = output.cpu()
        sample_auc = 0
        sample_mle = 0
        sample_mse = 0
        sample_nmse = 0
        num_timesteps = target.shape[1]
        for idx in tqdm(range(output.shape[0]), position=0, desc="sample"):
            sample = batch[idx]
            max_timestep = (sample==torch.max(sample)).nonzero()[0][0]
            sample_target = np.array(target[idx, max_timestep, :])
            sample_output = np.array(output[idx, max_timestep, :])
            
            auc_close, auc_far = eval_auc(sample_target, sample_output, dipole_pos)
            sample_auc += auc_close + auc_far
            
            mle = eval_mean_localization_error(sample_target, sample_output, dipole_pos)
            mse = eval_mse(sample_target, sample_output)
            nmse = eval_nmse(sample_target, sample_output)

        sample_auc = sample_auc/(2*num_timesteps*batch.shape[0])
        sample_mle = sample_mle/(num_timesteps*batch.shape[0])
        sample_mse = sample_mse/(num_timesteps*batch.shape[0])
        sample_nmse = sample_nmse/(num_timesteps*batch.shape[0])
        all_auc.append(sample_auc)
        all_mle.append(sample_mle)
        all_mse.append(sample_mse)
        all_nmse.append(sample_nmse)

  y_est_normed = y_est / np.max(np.abs(y_est))
  y_est_normed = y_est / np.max(np.abs(y_est))
  y_est_normed = y_est / np.max(np.abs(y_est))
  y_est_normed = y_est / np.max(np.abs(y_est))
  y_est_normed = y_est / np.max(np.abs(y_est))
sample:  12%|█▎        | 32/256 [39:19<4:35:18, 73.74s/it]
batch:   0%|          | 0/59 [39:21<?, ?it/s]


KeyboardInterrupt: 