# Note
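本文档基于分布计算得到的 outlier detection（离群点检测）数据库，对每个电极的每个通道进行分段高斯过程回归（piecewise GPR）插值。
(This notebook runs piecewise Gaussian-process regression (GPR) interpolation, per electrode and per channel, on the outlier-detection database produced by the earlier distributed computation.)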

# Import

In [None]:
# from Outlier import OutlierDetection
import Interpolation

import os
import re
import sys
from loguru import logger

import matplotlib.pyplot as plt 

from datetime import datetime

import numpy as np
import torch

# Filesys

In [None]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Recursively search rootPath for electrode directories.

        A sub-directory whose name matches ``ele_pattern`` (via
        ``pattern.match``) is recorded as an electrode directory and is
        NOT descended into; every other sub-directory is searched
        recursively. Plain files are ignored.

        Parameter:
            rootPath: directory to start searching from
            ele_pattern: compiled regex, or a pattern string, matched
                against each directory name; group(1) is taken as the
                electrode ID
        Return:
            ele_list: list of [electrode_dir_path, electrode_id] pairs,
                in sorted directory-name order (depth-first)
        ==================================================
    '''
    # Accept a raw pattern string as well as a pre-compiled regex.
    if isinstance(ele_pattern, str):
        ele_pattern = re.compile(ele_pattern)

    ele_list = []
    # os.listdir order is arbitrary and filesystem-dependent; sort it so the
    # electrode indices (and the ELE[i/n] progress logs) are reproducible.
    for name in sorted(os.listdir(rootPath)):
        sub_path = os.path.join(rootPath, name)
        if not os.path.isdir(sub_path):
            continue
        match_ele = ele_pattern.match(name)
        if match_ele:
            ele_list.append([sub_path, match_ele.group(1)])
        else:
            ele_list.extend(SearchELE(sub_path, ele_pattern))

    return ele_list

In [None]:
def setup_logger(log_dir="./LOG", log_filename="file.log", file_level="WARNING", console_level="WARNING"):
    '''==================================================
        Reset the loguru logger to one stdout sink and one file sink.

        If <log_dir>/<log_filename> already exists it is archived first as
        <name>_<YYYYmmdd_HHMMSS><ext> (with an extra _<n> counter if that
        name is already taken), so each run starts a fresh log file.
        Parameter:
            log_dir: directory for the log file (created if missing)
            log_filename: name of the active log file
            file_level: minimum level written to the file sink
            console_level: minimum level written to stdout
        Return:
            the configured loguru ``logger``
        ==================================================
    '''
    os.makedirs(log_dir, exist_ok=True)
    log_fd = os.path.join(log_dir, log_filename)

    # Remove every existing sink (including a file sink from a previous call)
    # before touching the log file below.
    logger.remove()

    # Archive an existing log file under a timestamped name.
    if os.path.exists(log_fd):
        name, ext = os.path.splitext(log_filename)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        archived_path = os.path.join(log_dir, f"{name}_{timestamp}{ext}")
        # The timestamp has 1 s resolution; two calls within the same second
        # would collide (os.rename overwrites on POSIX, raises on Windows).
        counter = 1
        while os.path.exists(archived_path):
            archived_path = os.path.join(log_dir, f"{name}_{timestamp}_{counter}{ext}")
            counter += 1
        os.rename(log_fd, archived_path)

    # Console sink.
    logger.add(sys.stdout, level=console_level, enqueue=True)

    # File sink.
    logger.add(log_fd, level=file_level, encoding="utf-8", enqueue=True)

    return logger

# Run

In [None]:
if True:
    # Raw string: the backslashes in this Windows path ("\B", "\E", "\L",
    # "\G") are invalid escape sequences in a normal literal and raise a
    # SyntaxWarning on Python 3.12+. The resulting path value is unchanged.
    setup_logger(log_dir=r"D:\Baihm\EISNN\LOG\GPR_outlier_Ver04")

# logger.remove()
# logger.add(sys.stdout, level="WARNING")
# logger.add("./LOG/file.log", rotation="10 MB", level="INFO")

In [None]:
# rootPath = "D:/Baihm/EISNN/Archive/"
# ele_list = SearchELE(rootPath)
# DATASET_SUFFIX = "Outlier_Ver03"

# rootPath = "D:/Baihm/EISNN/Archive_New/"
# ele_list = SearchELE(rootPath)
# DATASET_SUFFIX = "Outlier_Ver04"

# Active configuration: in-vivo dataset, whose electrode directories are
# named "<electrode_id>_Ver02" (group 1 of the regex is the electrode ID).
rootPath = "D:/Baihm/EISNN/Invivo/"
DATASET_SUFFIX = "Outlier_Ver04"
ele_list = SearchELE(rootPath, ele_pattern=re.compile(r"(.+?)_Ver02"))

n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")

## Each Electrode

In [None]:
# 101 evenly spaced integer indices spanning 0..4999 (both ends included);
# used to down-sample channel data whose axis 2 has 5000 points.
freq_list = np.linspace(0, 4999, num=101, endpoint=True, dtype=int)

# Output tag: <dataset suffix>_<kernel>_<version>.
MODEL_SUFFIX = f"{DATASET_SUFFIX}_Matern12_Ver01"
SAVE_FLAG = True


In [None]:

# Font for the per-channel annotation panel; constant, so built once.
font_properties = {
    'family': 'monospace',  # fixed-width font so the label columns line up
    'size': 14,
    'weight': 'bold',
}

for i, (ele_dir, ele_name) in enumerate(ele_list):
    # Per-electrode input produced by the outlier-detection stage.
    fd_pt = os.path.join(ele_dir, DATASET_SUFFIX, f"{ele_name}_{DATASET_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        logger.warning(f"{fd_pt} does not exist")
        continue

    # weights_only=False: meta_group holds non-tensor Python objects (the
    # TimeSpan entries are formatted with .strftime below), which the
    # weights_only=True default of torch>=2.6 refuses to unpickle. This runs
    # pickle, so only load trusted, locally produced files.
    data_pt = torch.load(fd_pt, weights_only=False)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]

    ele_id = _meta_group["ele_id"]
    elePath = _meta_group["elePath"]
    n_ch = _meta_group["n_ch"]
    x_day_full = _meta_group["TimeSpan"]

    logger.warning(f"ELE[{i+1}/{n_ele}]: \t{ele_id} - {elePath}")

    # Outputs: <elePath>/<MODEL_SUFFIX>/<ele_id>_<MODEL_SUFFIX>.pt plus one PNG per channel.
    save_dir = f"{elePath}/{MODEL_SUFFIX}/"
    pt_file_name = f"{ele_id}_{MODEL_SUFFIX}.pt"
    os.makedirs(save_dir, exist_ok=True)
    if os.path.exists(os.path.join(save_dir, pt_file_name)):
        logger.warning(f"FileAlreadyExistsWarning: {ele_id} - {pt_file_name} already exists.")
        # With SAVE_FLAG set, already-processed electrodes are skipped;
        # otherwise they are recomputed (figures rewritten, .pt not saved).
        if SAVE_FLAG:
            continue

    failed_channels = []
    for j in _data_group['Channels']:
        fig = None
        try:
            logger.info(f"ELE[{ele_id}] - ch[{j}] Begin")
            channel_group_raw = _data_group[j]

            chData      = channel_group_raw['chData']
            eis_seq     = channel_group_raw['eis_seq']
            eis_cluster = channel_group_raw['eis_cluster']
            eis_anomaly = channel_group_raw['eis_anomaly']

            # Data with 5000 points on axis 2 (presumably the frequency axis)
            # is down-sampled to the 101 indices in freq_list.
            if chData.shape[2] == 5000:
                chData = chData[:, :, freq_list]

            # Piecewise GPR interpolation
            x_train_full, y_train_full, x_eval_full, y_eval_full, y_eval_err_full, eis_cluster_eval = \
                Interpolation.PiecewiseGPR(x_day_full, chData, eis_seq, eis_cluster, SPEED_RATE = 2, training_iter = 200, lr = 0.05)

            logger.info(f"ELE[{ele_id}] - ch[{j}] Interpolation Finished")

            # Diagnostic figure; the 12th cell of a 3x4 grid holds the text label.
            fig = plt.figure(figsize=(16, 9), constrained_layout=True)
            Interpolation.EISPreprocessPlot(fig, chData, x_train_full, y_train_full, x_eval_full, y_eval_full, y_eval_err_full, eis_seq, eis_cluster, eis_anomaly)

            axis = fig.add_subplot(3,4,12)
            axis.axis('off')
            text = f"EIE  : {ele_id}\nCHID : {j:03d}\nFrom : {x_day_full[0].strftime('%Y-%m-%d')}\nTo   : {x_day_full[-1].strftime('%Y-%m-%d')}"
            axis.text(0.2, 0.5, text, fontdict = font_properties, ha='left', va='center')

            # save_dir was already created above.
            fig_name = f"EISGPR_{ele_id}_ch{j:03d}.png"
            fig.savefig(os.path.join(save_dir, fig_name))

            # Replace the raw channel entry with the interpolation results.
            channel_group_intp = {}
            channel_group_intp['chData_intp_mean']  = y_eval_full
            channel_group_intp['chData_intp_var']   = y_eval_err_full
            channel_group_intp['x_train']           = x_train_full
            channel_group_intp['x_eval']            = x_eval_full
            channel_group_intp['x_eval_cluster']    = eis_cluster_eval

            _data_group[j] = channel_group_intp
            logger.info(f"ELE[{ele_id}] - ch[{j}] Finished")

        except Exception as e:
            # Best-effort per channel: keep the traceback in the log, then move on.
            logger.opt(exception=True).warning(f"ELE[{ele_id}] - ch[{j}] Run with error: {e}")
            failed_channels.append(j)
        finally:
            # Always release the figure, even when plotting/saving raised;
            # otherwise open figures accumulate over the whole batch.
            if fig is not None:
                plt.close(fig)

    if failed_channels:
        # NOTE(review): failed channels keep their raw dict in _data_group and
        # remain listed in 'Channels'; readers of the saved file must check for
        # 'chData_intp_mean' before using a channel.
        logger.warning(f"ELE[{ele_id}] failed channels: {failed_channels}")

    pt_store = {}
    pt_store["meta_group"] = _meta_group
    pt_store["data_group"] = _data_group
    if SAVE_FLAG:
        torch.save(pt_store, os.path.join(save_dir, pt_file_name))




# Load Test

In [None]:

if False:
    # Manual spot-check of one saved result file (disabled by default).
    # Forward slashes avoid the invalid "\B", "\E", "\A", "\M" escape
    # sequences of a backslash literal (SyntaxWarning on Python 3.12+).
    # NOTE(review): this path uses the older "Matern12_Ver01" naming, not the
    # current MODEL_SUFFIX layout — update before enabling.
    pt_name = "D:/Baihm/EISNN/Archive/01037160_归档/Matern12_Ver01/01037160_Matern12_Ver01.pt"
    # weights_only=False: the stored dict may contain non-tensor objects
    # (e.g. the TimeSpan dates in meta_group), rejected by torch>=2.6's default.
    loaded = torch.load(pt_name, weights_only=False)

# Fix x_train

In [None]:


# MODEL_SUFFIX = "Matern12_Ver01"

# all_data_list = []

# for i in range(n_ele):
# # for i in range(3):
#     fd_pt = os.path.join(ele_list[i][0], MODEL_SUFFIX, f"{ele_list[i][1]}_{MODEL_SUFFIX}.pt")
#     if not os.path.exists(fd_pt):
#         # logger.warning(f"{fd_pt} does not exist")
#         continue
#     data_pt = torch.load(fd_pt, weights_only=False)
#     _meta_group = data_pt["meta_group"]
#     _data_group = data_pt["data_group"]

#     n_day       = _meta_group["n_day"]
#     n_ch        = _meta_group["n_ch"]
#     n_valid_ch  = len(_data_group["Channels"])