In [1]:
%load_ext autoreload
%autoreload 2
from convnet import ConvDipNet
from torchinfo import summary
import torch.nn as nn
import numpy as np
import torch
import sys; sys.path.insert(0, '../')
from esinet.forward import create_forward_model, get_info
from CNN_LSTM.util import *
from dipoleDataset import DipoleDataset
import os
from esinet.evaluate import eval_auc, eval_nmse, eval_mse, eval_mean_localization_error
import json
from util import solve_p

2025-11-29 21:55:35.912600: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-29 21:55:35.967707: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Hyperparameters of the single-input ConvDipNet. They determine the parameter shapes,
# so they must match those of the checkpoint loaded in the next cell
# (load_state_dict is strict by default and rejects mismatched shapes).
in_channels = 1
im_shape = (9, 9)      # (height, width) of the input image grid
n_filters = 8
kernel_size = (3, 3)

# create single input ConvDipNet
convnet: nn.Module = ConvDipNet(in_channels, im_shape, n_filters, kernel_size)

# print model summary for a dummy batch of 32 samples.
# The network takes 4-D input: (batch_size, in_channels, height, width).
summary(convnet, input_size=(32, in_channels, *im_shape))

Layer (type:depth-idx)                   Output Shape              Param #
ConvDipNet                               [32, 5124]                --
├─Conv2d: 1-1                            [32, 8, 9, 9]             80
├─BatchNorm2d: 1-2                       [32, 8, 9, 9]             16
├─Conv2d: 1-3                            [32, 8, 9, 9]             584
├─BatchNorm2d: 1-4                       [32, 8, 9, 9]             16
├─Dropout: 1-5                           [32, 8, 9, 9]             --
├─Linear: 1-6                            [32, 512]                 332,288
├─BatchNorm1d: 1-7                       [32, 512]                 1,024
├─Linear: 1-8                            [32, 5124]                2,628,612
Total params: 2,962,620
Trainable params: 2,962,620
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 96.50
Input size (MB): 0.01
Forward/backward pass size (MB): 2.24
Params size (MB): 11.85
Estimated Total Size (MB): 14.10

In [3]:
# Load the trained weights from checkpoint `convdip_40.pt` of run 6
# (presumably the epoch-40 snapshot — confirm against the training log).
model_dir = "/mnt/data/convdip/model/convdip_run6"
model_weight_path = os.path.join(model_dir, "convdip_40.pt")
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine as well; move the model with `.to(device)` afterwards.
convnet.load_state_dict(
    torch.load(model_weight_path, map_location="cpu", weights_only=True)
)
# Inference mode for layers such as Dropout and BatchNorm.
convnet.eval()

ConvDipNet(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (hidden_layer): Linear(in_features=648, out_features=512, bias=True)
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (output_layer): Linear(in_features=512, out_features=5124, bias=True)
)

In [4]:
# Paths of the simulated data set. model_save_path / loss_save_path are not used by the
# cells visible here; they are kept so existing references keep working.
model_save_path = "/mnt/data/convdip/model/"
loss_save_path = "/mnt/data/convdip/model/convdip_loss.npy"
data_path = "/mnt/data/convdip/training_data/"
eeg_data_path = os.path.join(data_path, "eeg_data")
interp_data_path = os.path.join(data_path, "interp_data")
source_data_path = os.path.join(data_path, "source_data")
info_path = os.path.join(data_path, "info.fif")

dataset = DipoleDataset(eeg_data_path, interp_data_path, source_data_path, info_path, im_shape=im_shape)
test_size = 0.15
val_size = 0.15

n_samples = len(dataset)
test_amount = int(n_samples * test_size)
val_amount = int(n_samples * val_size)
train_amount = n_samples - (test_amount + val_amount)

# Fixed seed so the data is split the same way each time. random_split draws one
# permutation of all indices from `gen` and cuts it into consecutive chunks of the
# given lengths, in order.
# The lengths are listed in the SAME order as the names they are unpacked into
# (train, val, test). While test_size == val_size this yields exactly the same subsets
# as listing them as [train, test, val]. If the sizes ever differ, make sure the
# training code uses this same ordering, or the held-out sets will not match.
gen = torch.Generator()
gen.manual_seed(0)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [train_amount, val_amount, test_amount], generator=gen
)

B = 512  # batch size
# No shuffling: iteration order stays fixed for evaluation.
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=B, shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=B, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=B, shuffle=False)

In [5]:
# Dipole positions, passed to the localization metrics during evaluation.
dipole_pos_file = os.path.join(data_path, "dipole_pos.npy")
dipole_pos = np.load(dipole_pos_file)

# Peek at the first validation item. Each item unpacks into
# (dataset index, input sample, target).
idx, sample, target = val_set[0]
for shown in (idx, sample.shape):
    print(shown)


24175
torch.Size([1, 9, 9])


In [6]:
# Build the 'ico4' forward model via esinet and keep its solution matrix
# (fwd['sol']['data']) as the leadfield, which the evaluation loop passes to solve_p.
fs = 100  # value passed as sfreq (presumably the sampling rate in Hz)
info = get_info(sfreq=fs)
fwd = create_forward_model(info=info, sampling='ico4')
leadfield = fwd['sol']['data']

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 80 concurrent workers.
[Parallel(n_jobs=-1)]: Done  18 out of  80 | elapsed:    5.9s remaining:   20.2s
[Parallel(n_jobs=-1)]: Done  35 out of  80 | elapsed:    6.0s remaining:    7.7s
[Parallel(n_jobs=-1)]: Done  52 out of  80 | elapsed:    6.0s remaining:    3.2s
[Parallel(n_jobs=-1)]: Done  69 out of  80 | elapsed:    6.0s remaining:    1.0s
[Parallel(n_jobs=-1)]: Done  80 out of  80 | elapsed:    6.1s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 80 concurrent workers.
[Parallel(n_jobs=-1)]: Done  18 out of  80 | elapsed:    0.1s remaining:    0.3s
[Parallel(n_jobs=-1)]: Done  35 out of  80 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  52 out of  80 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  69 out of  80 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done  80 out of  80 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 80

In [None]:
# Evaluate the trained network on the test split and store per-sample metrics as JSON.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
metric_save_path = os.path.join(model_dir, "evaluation_metrics.json")
convnet.to(device)

# data_idx -> [auc_close, auc_far, mle, mse, nmse]. JSON turns the int keys into strings.
metrics_per_sample = {}
with torch.no_grad():
    # NOTE(review): `tqdm` is not imported explicitly in this notebook; it presumably
    # comes from `from CNN_LSTM.util import *` — consider importing it directly.
    for idxs, batch, target in tqdm(test_dataloader, position=0, desc="batch"):
        batch = batch.to(device, dtype=torch.float)
        output = convnet(batch).cpu()

        for idx in tqdm(range(output.shape[0]), position=1, leave=False, desc="sample"):
            data_idx = int(idxs[idx])
            target_sample = np.array(target[idx])
            output_sample = np.array(output[idx])

            # Raw EEG of this sample; assumed (n_channels, n_timesteps) — TODO confirm.
            eeg = np.load(os.path.join(eeg_data_path, f"sample_{data_idx}.npy"))
            # Timestep holding the global maximum EEG value (the timestep used for training).
            # NOTE(review): this is the largest *signed* value, not the largest |value| —
            # confirm that matches how the training samples were built.
            max_idx = np.unravel_index(np.argmax(eeg), eeg.shape)[1]
            # Post-process the network output with solve_p, using the EEG at that timestep
            # and the leadfield.
            output_sample = solve_p(output_sample, eeg[:, max_idx], leadfield)

            auc_close, auc_far = eval_auc(target_sample, output_sample, dipole_pos)
            mle = eval_mean_localization_error(target_sample, output_sample, dipole_pos)
            mse = eval_mse(target_sample, output_sample)
            nmse = eval_nmse(target_sample, output_sample)
            # Cast to built-in float: json cannot serialise numpy scalars such as float32.
            # NaN values are still written (json's allow_nan default) and read back by json.load.
            metrics_per_sample[data_idx] = [
                float(auc_close), float(auc_far), float(mle), float(mse), float(nmse)
            ]

        # Checkpoint after every batch (one batch takes tens of minutes) so an interrupted
        # run keeps what it has computed. Write to a temp file and swap it in atomically so
        # an interrupt during the write cannot leave a truncated JSON file behind.
        tmp_path = metric_save_path + ".tmp"
        with open(tmp_path, "w") as json_file:
            json.dump(metrics_per_sample, json_file)
        os.replace(tmp_path, metric_save_path)
        

Using device: cuda
1


sample:   5%|▍         | 25/512 [01:53<36:59,  4.56s/it] 


KeyboardInterrupt: 

In [28]:
# Reload the per-sample metrics written by the evaluation loop. JSON object keys are
# strings, so dataset indices come back as str; each value is
# [auc_close, auc_far, mle, mse, nmse].
with open(metric_save_path, mode='r') as metrics_fh:
    metrics = json.load(metrics_fh)

print(metrics)

{'9771': [0.4145663988657845, 0.29337429111531194, 36.312774929301, 0.009217227703145066, 0.13296913452703557], '4374': [0.22827919333413837, 0.08673348629392585, 35.02693422038831, 0.01775956286581762, 0.21168121131165538], '54679': [0.32557755102040814, 0.11291224489795919, nan, 0.014980096009553266, 0.14347386502750029]}
