# Import

In [1]:
import os
import re
import gc
import sys
from loguru import logger

import matplotlib.pyplot as plt 

from datetime import datetime

import numpy as np
import torch


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
from HETSFileHelper import gatherCSV, readChannel, EIS_recal_ver02
from Outlier import OutlierDetection
from EISGPR import Interpolation


# %matplotlib qt

# Filesys

In [2]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in the rootPath
        Parameter: 
            rootPath: current search path
            ele_pattern: electrode dir name patten
        Returen:
            ele_list: list of electrode directories
        ==================================================
    '''
    ele_list = []
    for i in os.listdir(rootPath):
        _path = os.path.join(rootPath, i)
        if os.path.isdir(_path):
            match_ele = ele_pattern.match(i)
            if match_ele:
                ele_list.append([_path, match_ele.group(1)])
            else:
                ele_list.extend(SearchELE(_path, ele_pattern))

    return ele_list

In [110]:
# rootPath = "D:/Baihm/EISNN/Archive/"
# ele_list = SearchELE(rootPath)


rootPath = "D:/Baihm/EISNN/Archive_New/"
ele_list = SearchELE(rootPath)

# rootPath = "D:/Baihm/EISNN/Invivo/"
# ele_list = SearchELE(rootPath, re.compile(r"(.+?)_Ver02"))

n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")

[32m2025-06-11 15:55:45.221[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mSearch in D:/Baihm/EISNN/Archive_New/ and find 187 electrodes[0m


# Error Processed Statistic

In [111]:
# 我们观察到，由于我们在最后聚类的时候使用了AP + silhouette_score
# 而silhouette_score 对最低样本数有要求
# 这使得我们会遇到大量报错，之前用try exception跳过了，但是这个可能会导致我们把正常电极误判
# 这里我们打印每个pt文件中，有效电极数和追踪天数
# 如果有效电极数 < 128 - 10 且追踪天数比较多，就认为有问题

DATASET_SUFFIX = "Outlier_Ver04"

n_miss_ele      = 0
n_avaliable_ele = 0


n_all_days      = []



n_few_error     = []
n_open_error    = []
n_nan_error     = []
n_good          = []



for i in range(n_ele):
# for i in range(3):
    # logger.info(f"ELE Begin: {ele_list[i][0]}")
    fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        n_miss_ele = n_miss_ele + 1
        logger.warning(f"{fd_pt} does not exist")
        continue
    data_pt = torch.load(fd_pt)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]


    n_day = _meta_group["n_day"]    
    n_ch = _meta_group["n_ch"]         



    
    ch_few_error = _meta_group["ch_few_error"]  
    ch_open_error = _meta_group["ch_open_error"] 
    ch_nan_error = _meta_group["ch_nan_error"]  
    ch_good = _meta_group["ch_good"]       

    n_avaliable_ele = n_avaliable_ele + len(ch_good)
    n_all_days.append(n_day)
    n_few_error.append(len(ch_few_error))
    n_open_error.append(len(ch_open_error))
    n_nan_error.append(len(ch_nan_error))
    n_good.append(len(ch_good))

    

    logger.info(f"{ele_list[i][1]}[{i:03d}] - [{n_day}]: Error:{len(ch_few_error)} Open:{len(ch_open_error)} Nan:{len(ch_nan_error)} Good:{len(ch_good)}/{n_ch} ")


  data_pt = torch.load(fd_pt)
[32m2025-06-11 15:55:50.191[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m02027452[000] - [27]: Error:0 Open:41 Nan:0 Good:87/128 [0m
[32m2025-06-11 15:55:53.225[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m02027453[001] - [27]: Error:0 Open:46 Nan:0 Good:82/128 [0m
[32m2025-06-11 15:55:54.814[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m11037287[006] - [9]: Error:17 Open:6 Nan:0 Good:105/128 [0m
[32m2025-06-11 15:55:57.630[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m16057219[010] - [15]: Error:0 Open:6 Nan:0 Good:122/128 [0m
[32m2025-06-11 15:56:01.687[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m16057220[011] - [22]: Error:0 Open:9 Nan:0 Good:119/128 [0m
[32m2025-06-11 15:56:04.309[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1m16057221[012] - [

In [112]:
n_all_days = np.array(n_all_days)
n_few_error = np.array(n_few_error)
n_open_error = np.array(n_open_error)
n_nan_error = np.array(n_nan_error)
n_good = np.array(n_good)

In [113]:
cnt_few_error  =  n_few_error  
cnt_open_error =  n_open_error
cnt_nan_error  =  n_nan_error 
cnt_good       =  n_good      
logger.info(f"\n cnt_few_error:{np.sum(cnt_few_error)}\
            \ncnt_open_error:{np.sum(cnt_open_error)}\
            \ncnt_nan_error:{np.sum(cnt_nan_error)}\
            \ncnt_good:{np.sum(cnt_good)}\
            \nsum:{np.sum(cnt_few_error)+np.sum(cnt_open_error)+np.sum(cnt_nan_error)+np.sum(cnt_good)}")

[32m2025-06-11 15:59:00.567[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
 cnt_few_error:2691            
cnt_open_error:3557            
cnt_nan_error:0            
cnt_good:9644            
sum:15892[0m


In [114]:
cnt_few_error  =  n_all_days * n_few_error  
cnt_open_error =  n_all_days * n_open_error
cnt_nan_error  =  n_all_days * n_nan_error 
cnt_good       =  n_all_days * n_good      
logger.info(f"\n cnt_few_error:{np.sum(cnt_few_error)}\
            \ncnt_open_error:{np.sum(cnt_open_error)}\
            \ncnt_nan_error:{np.sum(cnt_nan_error)}\
            \ncnt_good:{np.sum(cnt_good)}\
            \nsum:{np.sum(cnt_few_error)+np.sum(cnt_open_error)+np.sum(cnt_nan_error)+np.sum(cnt_good)}")

[32m2025-06-11 15:59:12.817[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
 cnt_few_error:16020            
cnt_open_error:46380            
cnt_nan_error:0            
cnt_good:105836            
sum:168236[0m


In [None]:
print(n_ele,n_miss_ele)

218 65


In [107]:
import matplotlib.pyplot as plt
import numpy as np

# 设置数据
labels = ['In vitro', 'In vivo']
categories = ['Good', 'Abnormal', 'Open']
main_colors = ['green', 'orange', 'red']
light_colors = ['lightgreen', 'khaki', 'lightcoral']

# 主数据集（前面给出的）
main_counts = {
    'Good':     [229803, 13707],
    'Abnormal':     [30585, 245],
    'Open': [102196, 0],
}

# 对比数据集（新提供的）
comp_counts = {
    'Good':     [21878, 719],
    'Abnormal':     [5013, 49],
    'Open': [8432, 0],
}

# 计算总和用于比例
main_totals = [sum([main_counts[cat][i] for cat in categories]) for i in range(2)]
comp_totals = [sum([comp_counts[cat][i] for cat in categories]) for i in range(2)]

# 计算比例
main_ratios = {cat: [main_counts[cat][i] / main_totals[i] for i in range(2)] for cat in categories}
comp_ratios = {cat: [comp_counts[cat][i] / comp_totals[i] for i in range(2)] for cat in categories}

# 准备绘图
x = np.arange(len(labels))  # [0, 1] for 'in vitro', 'in vivo'
width = 0.2  # 每根柱子宽度

fig, ax = plt.subplots(figsize=(4, 6))

width_weight = 0.4

for i, cat in enumerate(categories):
    # 左右偏移
    main_pos = x + ((i-1) * width)  # 主数据左边三根柱子
    comp_pos = main_pos + width_weight*width       # 对比数据右边三根柱子

    # 主数据柱状图
    main_bars = ax.bar(
        main_pos, main_ratios[cat], width=width_weight*width, color=main_colors[i], label=f"#{cat} (samples)", zorder=3
    )
    comp_bars = ax.bar(
        comp_pos, comp_ratios[cat], width=width_weight*width, color=light_colors[i], label=f"#{cat} (electrodes)", zorder=3
    )

    # 添加文字标签（主数据）
    for j in range(len(main_bars)):
        height = main_bars[j].get_height()
        if height <= 1.0:
            ax.text(
                main_bars[j].get_x() + main_bars[j].get_width()/2 , height + 0.02,
                f"{main_counts[cat][j]}",
                ha='center', va='bottom', fontsize=12, rotation=45, fontweight='bold'
            )
    # 添加文字标签（对比数据）
    for j in range(len(comp_bars)):
        height = comp_bars[j].get_height()
        if height <= 1.0:
            ax.text(
                comp_bars[j].get_x() + comp_bars[j].get_width()/2, height + 0.02,
                f"{comp_counts[cat][j]}",
                ha='center', va='bottom', fontsize=12, rotation=45, fontweight='bold'
            )



# 图例、坐标轴、网格等美化
ax.set_ylim(0, 1.2)
ax.set_ylabel("Percentage", fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=16, fontweight='bold')
ax.set_yticks(np.linspace(0, 1.0, 11))
ax.set_yticklabels([f"{int(t * 100)}%" for t in np.linspace(0, 1.0, 11)])

ax.legend(title="Category", fontsize=10, loc=(2))
ax.grid(True, axis='y', linestyle='--', alpha=0.5, zorder=0)

plt.title(f"EIS Data Summary")
plt.tight_layout()
plt.show()
