# Imports

In [10]:
from HETSFileHelper import gatherCSV, readChannel, EIS_recal_ver02
from Outlier import OutlierDetection
from EISGPR import Interpolation

import os
import re
import sys
from loguru import logger

import matplotlib.pyplot as plt 

from datetime import datetime

import numpy as np
import torch

# Filesystem and logging helpers

In [11]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in the rootPath
        Parameter: 
            rootPath: current search path
            ele_pattern: electrode directory name pattern; group(1) of a
                         match is used as the electrode id
        Return:
            ele_list: list of [full_path, electrode_id] pairs, ordered by
                      directory name so the order is reproducible
        ==================================================
    '''
    ele_list = []
    # os.listdir order is arbitrary; sort so index-based selection
    # (e.g. ele_list[0]) picks the same electrode on every run/machine.
    for entry_name in sorted(os.listdir(rootPath)):
        match_ele = ele_pattern.match(entry_name)
        if not match_ele:
            continue
        full_path = os.path.join(rootPath, entry_name)
        # match() only anchors at the start, so files such as "X_归档.zip"
        # would match too; only directories can hold electrode data.
        if not os.path.isdir(full_path):
            continue
        ele_list.append([full_path, match_ele.group(1)])
    return ele_list


    

In [12]:
def setup_logger(log_dir="./LOG", log_filename="file.log", file_level="WARNING", console_level="WARNING"):
    """Reset loguru to a single stdout sink plus a single file sink.

    Parameters:
        log_dir: directory that holds the log file (created if missing).
        log_filename: name of the active log file inside log_dir.
        file_level: minimum level written to the file sink.
        console_level: minimum level written to stdout.
    Returns:
        The module-level loguru ``logger`` after reconfiguration.

    An existing log file is renamed to "<stem>_<YYYYmmdd_HHMMSS><ext>"
    (with an extra "_<n>" suffix if that name is taken) before the new
    file sink is attached, so each call starts a fresh log file.
    """
    # Create the log directory
    os.makedirs(log_dir, exist_ok=True)
    log_fd = os.path.join(log_dir, log_filename)

    # Drop every existing handler before touching the log file.
    # NOTE(review): presumably also needed on Windows so a file sink added by a
    # previous call releases its handle before the rename below.
    logger.remove()

    # If a log file already exists, archive it under a timestamped name
    if os.path.exists(log_fd):
        name, ext = os.path.splitext(log_filename)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        archived_path = os.path.join(log_dir, f"{name}_{timestamp}{ext}")
        # Re-running within the same second would reuse the archive name:
        # os.rename then raises FileExistsError on Windows, or silently
        # overwrites the earlier archive on POSIX. Add a counter instead.
        counter = 1
        while os.path.exists(archived_path):
            archived_path = os.path.join(log_dir, f"{name}_{timestamp}_{counter}{ext}")
            counter += 1
        os.rename(log_fd, archived_path)

    # Console output
    logger.add(sys.stdout, level=console_level, enqueue=True)

    # File output
    logger.add(log_fd, level=file_level, encoding="utf-8", enqueue=True)

    return logger

# Run

In [14]:
# Send WARNING-and-above messages to stdout and to ./LOG/file.log
# (a previous ./LOG/file.log is archived under a timestamped name first).
setup_logger()

<loguru.logger handlers=[(id=5, level=30, sink=stdout), (id=6, level=30, sink='./LOG\file.log')]>



In [None]:
# Archive root. The machine-specific default can be overridden with the
# EIS_ARCHIVE_ROOT environment variable instead of editing the notebook.
rootPath = os.environ.get("EIS_ARCHIVE_ROOT", "D:/Baihm/EISNN/Archive/")
# Electrode folders are named "<ele_id>_归档"; group(1) becomes the electrode id.
ele_list = SearchELE(rootPath, ele_pattern=re.compile(r"(.+?)_归档"))
n_ele = len(ele_list)
logger.warning(f"Search in {rootPath} and find {n_ele:03d} electrodes")

## Process each electrode (outlier detection, GPR interpolation, plotting, saving)

In [None]:
# Index grids into the last axis of each channel array: the full range for
# interpolation, and the range starting at index 1000 for outlier detection.
freq_list = np.linspace(0,5000-1,101,dtype=int, endpoint=True)
freq_list_DTW = np.linspace(1000,5000-1,101,dtype=int, endpoint=True)

MODEL_SUFFIX = "Matern12_Ver01"
SAVE_FLAG = False

# Debug limits: process only the first N electrodes / channels.
# Set to None to process everything.
N_ELE_LIMIT = 1
N_CH_LIMIT = 1

PHZ_CALIB_PATH = "./EISGPR/phz_Calib.txt"

# The phase calibration table is identical for every channel: load it once.
# A missing/broken file now fails here instead of being swallowed per channel
# (which, with SAVE_FLAG, would write empty .pt files that are later skipped).
phz_calibration = np.loadtxt(PHZ_CALIB_PATH)

# min() keeps the loop safe when fewer electrodes than the limit were found.
n_ele_run = n_ele if N_ELE_LIMIT is None else min(n_ele, N_ELE_LIMIT)
for i in range(n_ele_run):
    elePath = ele_list[i][0]
    ele_id = ele_list[i][1]
    logger.warning(f"ELE[{i+1}/{n_ele}]: \t{elePath}")

    # Storage preparation
    save_dir = f"{elePath}/{MODEL_SUFFIX}/"
    pt_file_name = f"{ele_id}_{MODEL_SUFFIX}.pt"
    os.makedirs(save_dir, exist_ok=True)
    if os.path.exists(os.path.join(save_dir, pt_file_name)):
        logger.warning(f"FileAlreadyExistsWarning: {ele_id} - {pt_file_name} already exists.")
        # Skip finished electrodes only when results would be written;
        # with SAVE_FLAG off the electrode is recomputed (figures only).
        if SAVE_FLAG:
            continue

    # Load EIS data
    EISDict = gatherCSV(elePath)
    n_day   = len(EISDict)
    if n_day < 3:
        logger.warning(f"IllegalInputError: {ele_id} only has {n_day} samples.")
        continue
    try:
        # NOTE(review): assumes EISDict keys are 'YYYYMMDD' strings iterated in
        # the same order readChannel uses for the first axis — TODO confirm.
        x_day_full = [datetime.strptime(date, '%Y%m%d') for date in EISDict.keys()]
    except Exception as e:
        logger.error(f"IllegalDateError: {ele_id} has wrong date format. Please check the saving file. Error Code: {e}")
        continue

    _key    = next(iter(EISDict))
    n_ch    = len(EISDict[_key])

    if n_ch != 128:
        logger.warning(f"ChannelNumberWarning: {ele_id} only has {n_ch} channels.")
        continue

    # Iterate over channels; a failing channel is logged and skipped.
    data_group = {}
    data_group['Channels']    = []
    n_ch_run = n_ch if N_CH_LIMIT is None else min(n_ch, N_CH_LIMIT)
    for j in range(n_ch_run):
        fig = None
        try:
            chData = readChannel(j, EISDict)
            chData_DTW = chData[:,:,freq_list_DTW]
            # Outlier detection on the freq_list_DTW subset
            eis_seq, eis_cluster, eis_anomaly, leaf_anomaly = OutlierDetection.OutlierDetection(chData_DTW)
            if np.shape(eis_seq)[0] < 3:
                logger.warning(f"OutlierDetectionWarning: {ele_id} - CHID[{j}] only has {np.shape(eis_seq)[0]} valid samples.")
                continue

            # Recalibrate every sample (first axis) in place, then keep the
            # freq_list subset used for interpolation.
            for k in range(np.shape(chData)[0]):
                chData[k,:,:] = EIS_recal_ver02(chData[k,:,:], phz_calibration)

            chData = chData[:,:,freq_list]
            if np.isnan(chData).any():
                logger.warning(f"OutlierDetectionWarning: {ele_id} - CHID[{j}] chData Invalid")
                continue

            # Piecewise GPR interpolation
            x_train_full, y_train_full, x_eval_full, y_eval_full, y_eval_err_full, eis_cluster_eval = \
                Interpolation.PiecewiseGPR(x_day_full, chData, eis_seq, eis_cluster, SPEED_RATE = 2, training_iter = 200, lr = 0.05)

            # Plot
            fig = plt.figure(figsize=(16, 9), constrained_layout=True)
            Interpolation.EISPreprocessPlot(fig, chData, x_train_full, y_train_full, x_eval_full, y_eval_full, y_eval_err_full, eis_seq, eis_cluster, eis_anomaly)

            # Last panel of the 3x4 grid: text summary instead of a plot
            axis = fig.add_subplot(3,4,12)
            axis.axis('off')
            font_properties = {
                'family': 'monospace',  # fixed-width font so the labels align
                'size': 14,
                'weight': 'bold'
            }

            text = f"ELE  : {ele_id}\nCHID : {j:03d}\nFrom : {x_day_full[0].strftime('%Y-%m-%d')}\nTo   : {x_day_full[-1].strftime('%Y-%m-%d')}"
            axis.text(0.2, 0.5, text, fontdict = font_properties, ha='left', va='center')

            # Save figure (save_dir was created before the channel loop)
            fig_name = f"EISGPR_{ele_id}_ch{j:03d}.png"
            path = os.path.join(save_dir, fig_name)
            fig.savefig(path)

            # Keep this channel's interpolation results
            channel_group = {}
            channel_group['x_train']    = x_train_full
            channel_group['y_train']    = y_train_full
            channel_group['x_eval']     = x_eval_full
            channel_group['y_eval']     = y_eval_full
            channel_group['y_eval_err'] = y_eval_err_full
            channel_group['eis_cluster_eval'] = eis_cluster_eval

            data_group[f"ch_{j:03d}"] = channel_group
            data_group['Channels'].append(f"ch_{j:03d}")
        except Exception as e:
            logger.warning(f"ELE[{i+1}/{n_ele}] - ch[{j+1}/{n_ch}] Run with error: {e}")
        finally:
            # Close the figure even if plotting/saving raised, so failed
            # channels do not accumulate open figures across the loop.
            if fig is not None:
                plt.close(fig)

    # Assemble the per-electrode payload (key names, including "Creater",
    # are kept as-is for compatibility with existing .pt readers).
    pt_store = {}
    meta_group = {}
    meta_group["ele_id"]    = ele_id
    meta_group["elePath"]   = elePath
    meta_group["TimeSpan"]  = x_day_full
    meta_group["n_day"]     = n_day
    meta_group["n_ch"]      = n_ch
    meta_group["Model"]     = MODEL_SUFFIX
    meta_group["Creater"]   = "Ming"
    meta_group['Date']      = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    pt_store['meta_group'] = meta_group
    pt_store['data_group'] = data_group
    if SAVE_FLAG:
        torch.save(pt_store, os.path.join(save_dir, pt_file_name))




# Load test: reload a saved `.pt` result file

In [None]:

# Set to True to reload one saved result file as a quick sanity check.
RUN_LOAD_TEST = False
if RUN_LOAD_TEST:
    # Forward slashes: the original backslashes formed invalid escape
    # sequences ("\B", "\E", ...) in a non-raw string literal.
    pt_name = "D:/Baihm/EISNN/Archive/01037160_归档/Matern12_Ver01/01037160_Matern12_Ver01.pt"
    # The payload stores datetime objects (meta_group["TimeSpan"]), which the
    # weights_only=True default of recent PyTorch refuses to unpickle.
    # weights_only=False runs full pickle: only load files this notebook wrote.
    loaded = torch.load(pt_name, weights_only=False)

# Fix x_train

In [None]:


# MODEL_SUFFIX = "Matern12_Ver01"

# all_data_list = []

# for i in range(n_ele):
# # for i in range(3):
#     fd_pt = os.path.join(ele_list[i][0], MODEL_SUFFIX, f"{ele_list[i][1]}_{MODEL_SUFFIX}.pt")
#     if not os.path.exists(fd_pt):
#         # logger.warning(f"{fd_pt} does not exist")
#         continue
#     data_pt = torch.load(fd_pt, weights_only=False)
#     _meta_group = data_pt["meta_group"]
#     _data_group = data_pt["data_group"]

#     n_day       = _meta_group["n_day"]
#     n_ch        = _meta_group["n_ch"]
#     n_valid_ch  = len(_data_group["Channels"])