# In [None]:
# %%
import pprint
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LinearRegression
from xgboost import XGBRegressor
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from category_encoders import TargetEncoder
from tsfresh import extract_features
# ComprehensiveFCParameters
from tsfresh.feature_extraction import MinimalFCParameters
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from itertools import product
from typing import List
import pymssql
import joblib
import os
import warnings
import base64
import traceback
# %autosave 10

import yaml
import os
import logging
from log.logging_setting import Log
from util import cfg

Log().create_logger('NP_Txn.log')
logger = logging.getLogger('NP_Txn.log')

# %%
# Configuration file shipped next to this module.
with open(os.path.join(os.path.dirname(__file__), 'config.yaml'), encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Folder receiving the CSV snapshots written during data preparation.
snapshot_folder = os.path.join(os.path.dirname(__file__), 'snapshot_file_txn')
# %%
# makedirs(exist_ok=True) is race-free and also creates missing parent
# directories, unlike an isdir()/mkdir() pair.
os.makedirs(snapshot_folder, exist_ok=True)

warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", pd.errors.DtypeWarning)

# %%


class NonDataError(Exception):
    """Raised when a query or intermediate result needed to build a dataset is empty."""


class Nonnextweek(Exception):
    """Project-specific exception; not raised in this part of the file.

    NOTE(review): the name suggests a missing next-week period — confirm with callers.
    """


def connect_db():
    """Open a new pymssql connection to the ETL database.

    Server (``S``) and database name (``N``) are read from the ``DB_ETL``
    section of ``cfg``. The returned connection must be closed by the caller.
    """
    etl_settings = cfg['DB_ETL']
    return pymssql.connect(server=etl_settings['S'], database=etl_settings['N'])

class Dataset:
    """Per-scenario dataset of weekly samples for the VIP entities of one alert scenario.

    Construction loads the scenario's alerts and, for the VIP entities, their
    transactions, risk factors, customer info and (ACC level) the
    party/account bridge from SQL Server. It derives the VIP list from alert
    counts per customer segment and builds ``train``/``val``/``test`` frames.
    Each frame has one row per VIP entity per Sunday, with tsfresh features of
    the look-back transactions and the latest risk/customer attributes as of
    that Sunday.
    """

    def __init__(self, tgt_date, scenario_name, group, level, bankL=True, is_Train=True):
        """Load all source data and build the train/val/test sample frames.

        Parameters
        ----------
        tgt_date : pandas.Timestamp
            Reference date. Pass a Timestamp (e.g. ``pd.to_datetime("2020-12-27")``),
            not a string: it is used in ``timedelta`` arithmetic. The test
            frame only contains rows when this date is a Sunday.
        scenario_name : str
            Scenario code. Its last five characters select the alerts
            (``TWN_<xxx>_<yy>``); the full name selects the transaction view
            ``VW_NP_FB_Scenario_<scenario_name>``.
        group : {"Person", "Corp"}
            Customer group; selects the ``*_<group>`` views and the segment map.
        level : {"PTY", "ACC"}
            Monitoring level: party (customer) or account.
        bankL : bool, default True
            Load data from SQL Server. ``False`` requests the retired bankT
            loader and raises ``NotImplementedError``.
        is_Train : bool, default True
            Selects the alert window used for the VIP list (see ``_get_vip_list``).

        Raises
        ------
        NonDataError
            When a required query or intermediate result is empty.
        """
        logger.info(f'------------Dataset--------------  \
                    \n scenario name: {scenario_name}   \
                    \n tgt_date: {tgt_date}\
                    \n group: {group}\
                    \n is_Train: {is_Train}\
                    \n---------------------------------')
        assert group in ["Person", "Corp"]
        assert level in ["PTY", "ACC"]

        self.tgt_date = tgt_date  # reference (base) date for every window
        self.scenario_name = scenario_name
        self.group = group
        self.level = level
        self.is_Train = is_Train
        self.n_days = 7  # look-back window length for transaction features
        self.rd_ndays = f"資料日期-{self.n_days}日"  # column holding each sample's window start
        self.cust_lists = {}  # cache for cust_seg_cust_list
        self.entity_level = "客戶編碼" if level == "PTY" else "帳戶編碼"

        self._tgt_group_info = None
        self.vip_list = None
        self.train_start_date = None
        self.train_end_date = self.tgt_date - timedelta(days=365)  # training data end date
        if bankL:
            alert, txn, risk, info, ac = self._get_data_from_sql()
        else:
            alert, txn, risk, info, ac = self._get_bankT_data()

        # Calendar features: ISO year/week/day, month, quarter, half-year
        # (1 or 2) and a flag for the 31-day months.
        self._alert_processed = (
            alert
            .drop(["警示編號", "監控層級", "觸發說明"], axis=1)
            .assign(**alert.資料日期.dt.isocalendar(),
                    month=alert.資料日期.dt.month,
                    quarter=alert.資料日期.dt.quarter,
                    hfy=alert.資料日期.dt.quarter.isin([3, 4]) + 1,
                    bigMon=alert.資料日期.dt.month.isin([1, 3, 5, 7, 8, 10, 12]).astype(int))
        )
        # NOTE(review): txn_c feeds the "C" features (labelled Credit in
        # _get_data) but is filtered on 'DEBIT', and txn_d ("D", Debit) on
        # 'CREDIT'. Left unchanged because trained models presumably depend
        # on the current feature names; confirm whether this is intentional.
        self.txn_c = txn[txn['tran_type'] == 'DEBIT'].drop("tran_type", axis=1)
        self.txn_d = txn[txn['tran_type'] == 'CREDIT'].drop("tran_type", axis=1)
        self.txn_t = txn[txn['tran_type'] == 'TXN'].drop("tran_type", axis=1)

        self._txn_processed = txn  # VIP transactions from _get_data_from_sql
        self._risk_processed = risk  # VIP risk-factor history
        self._info_processed = info  # VIP customer-info history
        self.ac = ac  # account -> customer bridge (None at PTY level)

        self.train = self._get_data("train")  # VIP training samples
        self.val = self._get_data("val")  # VIP validation samples
        self.test = self._get_data("test")  # VIP samples on tgt_date

        self.seg_mapping = self._get_cust_seg_def()

    # ------------------------ alert count summaries ------------------------

    @staticmethod
    def _count_alerts(alerts, key):
        """Count alerts per ``key`` group into a ``警示數量`` column."""
        return (
            alerts
            .groupby(key)
            .agg({"資料日期": "count"})
            .rename({"資料日期": "警示數量"}, axis=1)
            .reset_index()
        )

    def _restrict_to_vip(self, alerts, on):
        """Return ``alerts`` unchanged for ``on="all"``, or only VIP entities' rows for ``on="vip"``."""
        if on == "all":
            return alerts
        if on == "vip":
            return alerts[alerts[self.entity_level].isin(self.vip_list)]
        raise ValueError("Unknown Target data")

    def num_sam_dist(self, key: List[str], on: str, cust_seg: str = ""):
        """Count alerts per ``key``, optionally restricted to customers of one segment.

        Parameters
        ----------
        key : list of str
            Group-by columns of the processed alert frame.
        on : {"all", "vip"}
            Count all alerts or only those of VIP entities.
        cust_seg : str, optional
            Segment group (key of ``seg_mapping``) or raw segment code; the
            customers are taken from ``cust_seg_cust_list``.

        Raises
        ------
        ValueError
            If ``on`` is not ``"all"`` or ``"vip"``.
        """
        alerts = self._alert_processed
        if cust_seg:
            alerts = alerts[alerts["客戶編碼"].isin(self.cust_seg_cust_list(cust_seg))]
        return self._count_alerts(self._restrict_to_vip(alerts, on), key)

    def num_sam_dist2(self, key: List[str], on: str, cust_seg: str = ""):
        """Like ``num_sam_dist`` but filters on the alert's own ``cust_segmentation`` column.

        ``cust_seg`` must be a key of ``seg_mapping``. It is only looked up
        when non-empty, so the default call no longer raises ``KeyError``.

        Raises
        ------
        ValueError
            If ``on`` is not ``"all"`` or ``"vip"``.
        """
        alerts = self._alert_processed
        if cust_seg:
            alerts = alerts[alerts["cust_segmentation"].isin(self.seg_mapping[cust_seg])]
        return self._count_alerts(self._restrict_to_vip(alerts, on), key)

    # ------------------------ bankT data ------------------------

    def _get_bankT_data(self):
        """Placeholder for the retired bankT CSV loader.

        The former loader read alerts, customer info, the account bridge and
        transactions from local CSV files. It assumed a single scenario in the
        alert file, only persons or only corporations, and alert customer ids
        matching those of the customer info file.
        """
        raise NotImplementedError("bankT loader is no longer available; use bankL=True")

    # ------------------------ bankL data ------------------------

    @staticmethod
    def _res_to_df(rows, description):
        """Build a DataFrame from DB-API ``fetchall()`` rows and ``cursor.description``."""
        return pd.DataFrame(rows, columns=[column[0] for column in description])

    @staticmethod
    def _sql_in_list(values):
        """Render ``values`` as a T-SQL ``IN`` list literal, e.g. ``('a', 'b')``.

        Unlike ``str(tuple(values))``, this never emits the trailing comma of a
        one-element tuple (``('a',)`` is a syntax error). It also always
        single-quotes items, doubling embedded quotes.

        Raises
        ------
        NonDataError
            If ``values`` is empty (``IN ()`` is not valid T-SQL).
        """
        items = [str(value) for value in values]
        if not items:
            raise NonDataError('empty id list for SQL IN clause')
        return "(" + ", ".join("'" + item.replace("'", "''") + "'" for item in items) + ")"

    def _get_data_from_sql(self):
        """Load alerts and the VIP entities' details from SQL Server.

        Side effects: sets ``_tgt_group_info``, ``_alert_processed``,
        ``train_start_date`` and ``vip_list``. Writes CSV snapshots into
        ``snapshot_folder`` and a listing of customer ids to ``test.txt`` in
        the working directory. The file, cursor and connection are closed
        even when an error is raised.

        Returns
        -------
        tuple
            ``(alert, txn, risk, cust_info, ac)``; ``ac`` is ``None`` at PTY level.

        Raises
        ------
        NonDataError
            When alerts, matching customer info, the account bridge or the
            risk data are empty.
        """
        with open('test.txt', 'w', encoding='utf-8') as debug_file:
            conn = connect_db()
            try:
                cur = conn.cursor()
                try:
                    alert = self._load_alerts(cur, debug_file)
                    txn, risk, cust_info, ac = self._load_vip_details(cur)
                finally:
                    cur.close()
            finally:
                conn.close()
        return alert, txn, risk, cust_info, ac

    def _load_alerts(self, cur, debug_file):
        """Load the scenario's alerts, keep those of known customers, and derive the VIP list.

        Sets ``_tgt_group_info`` (latest-segment source), ``_alert_processed``,
        ``train_start_date`` and ``vip_list``. Returns the alert frame with
        the ``客戶分群`` segment column added.
        """
        scenario_code = self.scenario_name[-5:-2] + '_' + self.scenario_name[-2:]
        # ACC alerts keep the monitored account number; PTY alerts drop it.
        entity_column = ("primary_entity_number" if self.level == "PTY"
                         else "primary_entity_number as Acct_Num")
        cur.execute(
            "SELECT alert_id, primary_entity_level_code, " + entity_column +
            ", scenario_name, actual_values_text, run_date, cust_segmentation"
            ", Cust_No as party_number FROM "
            f"VW_FB_Alert_{self.group} with (nolock) where scenario_name='TWN_{scenario_code}'")
        alert = self._res_to_df(cur.fetchall(), cur.description)
        alert = alert[alert["primary_entity_level_code"] == self.level]
        if alert.shape[0] == 0:
            raise NonDataError('No alert')  # without alerts every prediction is 0

        if self.level == "PTY":
            alert = (alert
                     .drop("primary_entity_number", axis=1)
                     .rename({"party_number": "客戶編碼"}, axis=1))
        else:
            alert = alert.rename({"Acct_Num": "帳戶編碼", "party_number": "客戶編碼"}, axis=1)
            alert["帳戶編碼"] = alert["帳戶編碼"].apply(lambda x: x.replace(' ', ''))
        alert["客戶編碼"] = alert["客戶編碼"].apply(lambda x: x.replace(' ', ''))

        logger.info(f'alert part number {alert.客戶編碼.unique()}')
        alert.to_csv(os.path.join(snapshot_folder, 'alert1.csv'))
        debug_file.write('alert \n')
        debug_file.write('  '.join(alert.客戶編碼.unique()))

        # Party records of the alerted customers (source of the segment labels).
        cur.execute(
            "SELECT [party_number], [change_begin_date], [customer_segmentation] FROM "
            f"VW_NP_FSC_PARTY_DIM_{self.group} with (nolock) where party_number in "
            + self._sql_in_list(alert.客戶編碼.unique()))
        self._tgt_group_info = (
            self._res_to_df(cur.fetchall(), cur.description)
            .rename({"change_begin_date": "資料開始週期",
                     "party_number": "客戶編碼",
                     "customer_segmentation": "客戶分群"}, axis=1)
        )
        logger.info(
            f'_tgt_group_info part number {self._tgt_group_info.客戶編碼.unique()}')
        self._tgt_group_info.to_csv(os.path.join(snapshot_folder, 'tgt_group_info.csv'))
        debug_file.write('\n_tgt_group_info \n')
        debug_file.write('  '.join(self._tgt_group_info.客戶編碼.unique()))
        # Without matching party records every alert is filtered out below.
        if self._tgt_group_info.shape[0] == 0:
            raise NonDataError('No customer info for alerted customers')

        # Keep only alerts of customers present in this group's party table.
        alert = alert.rename({"run_date": "資料日期",
                              "alert_id": "警示編號",
                              "primary_entity_level_code": "監控層級",
                              "actual_values_text": "觸發說明"}, axis=1)
        alert = alert[alert["客戶編碼"].isin(self._tgt_group_info["客戶編碼"].unique())]
        logger.info(f'alert part number {alert.客戶編碼.unique()}')
        debug_file.write('\nalert \n')
        debug_file.write('  '.join(alert.客戶編碼.unique()))

        alert = self._alert_add_cust_seg(alert)
        self._alert_processed = alert
        logger.info(
            f'_alert_processed part number {self._alert_processed.客戶編碼.unique()}')
        debug_file.write('\n _alert_processed \n')
        debug_file.write('  '.join(self._alert_processed.客戶編碼.unique()))

        self.train_start_date = self.tgt_date - timedelta(days=365 * 2)
        self._alert_processed.to_csv(os.path.join(snapshot_folder, 'alert_proceessed.csv'))
        self.vip_list = self._get_vip_list()
        logger.info(f'_get_vip_list  {self.vip_list}')
        return alert

    def _load_vip_details(self, cur):
        """Query customer info, transactions and risk factors of ``self.vip_list``.

        At ACC level the party/account bridge is fetched first, and customer
        info and risk are looked up for the VIP accounts' customers.

        Returns
        -------
        tuple
            ``(txn, risk, cust_info, ac)``; ``ac`` is ``None`` at PTY level.

        Raises
        ------
        NonDataError
            If the bridge (ACC) or the risk data are empty.
        """
        vip_in = self._sql_in_list(self.vip_list)
        if self.level == "PTY":
            logger.info('PTY')
            ac = None
            party_in = vip_in
            txn_entity_column = "Cust_Num"
        else:
            logger.info('ACC')
            cur.execute(
                "SELECT [account_number], [party_number], [change_begin_date] "
                "FROM VW_NP_FSC_PARTY_ACCOUNT_BRIDGE with (nolock) where account_number in "
                + vip_in)
            ac = (
                self._res_to_df(cur.fetchall(), cur.description)
                .rename({"change_begin_date": "資料開始週期",
                         "party_number": "客戶編碼",
                         "account_number": "帳戶編碼"}, axis=1)
            )
            if ac.shape[0] == 0:
                raise NonDataError('no ac')
            party_in = self._sql_in_list(ac.客戶編碼.unique())
            txn_entity_column = "Acct_Num"

        cur.execute(
            f"SELECT * FROM VW_NP_FSC_PARTY_DIM_{self.group} with (nolock) where party_number in "
            + party_in)
        cust_info = (
            self._res_to_df(cur.fetchall(), cur.description)
            .rename({"change_begin_date": "資料開始週期",
                     "party_number": "客戶編碼",
                     "customer_segmentation": "客戶分群"}, axis=1)
            .drop("change_end_date", axis=1)
        )

        cur.execute(
            f"SELECT [Posted_Date_Key], [{txn_entity_column}], [Ccy_Amt], [transaction_cdi_desc] "
            f"FROM VW_NP_FB_Scenario_{self.scenario_name} with (nolock) "
            f"where {txn_entity_column} in " + vip_in)
        txn = (
            self._res_to_df(cur.fetchall(), cur.description)
            .rename({"Posted_Date_Key": "交易日期",
                     "Cust_Num": "客戶編碼",
                     "Acct_Num": "帳戶編碼",
                     "Ccy_Amt": "折台幣交易金額",
                     "transaction_cdi_desc": "tran_type"}, axis=1)
            .astype({"交易日期": str})
            .pipe(lambda x: x.assign(交易日期=pd.to_datetime(x.交易日期)))
        )

        risk_query = (
            f"SELECT * FROM VW_NP_FSC_PARTY_RISK_FACTOR_DIM_SCD_{self.group} with (nolock) "
            "where party_number in " + party_in)
        logger.info(f'{self.tgt_date} vip get risk table \n query\n {risk_query}')
        cur.execute(risk_query)
        risk = (
            self._res_to_df(cur.fetchall(), cur.description)
            .rename({"change_begin_date": "資料開始週期",
                     "party_number": "客戶編碼"}, axis=1)
            .drop("change_end_date", axis=1)
            .pipe(lambda x: x.assign(資料開始週期=pd.to_datetime(x.資料開始週期)))
        )
        if risk.shape[0] == 0:
            raise NonDataError('no risk data')
        return txn, risk, cust_info, ac

    # ------------------------ customer segments ------------------------

    def _get_cust_seg_def(self):
        """Map segment-group names to raw customer-segmentation codes for ``self.group``.

        An extra ``"all"`` entry lists every code of the group.
        """
        if self.group == "Person":
            def_mapping = {
                'PEH': ['PE', 'PH'],
                'PML': ['PM', 'PL', 'P'],
            }
        else:
            def_mapping = {
                'CEH1': ['CE1', 'CH1'],
                'CML1': ['CM1', 'CL1'],
                'CEH2': ['CE2', 'CH2'],
                'CML2': ['CM2', 'CL2'],
                'CEH3': ['CE3', 'CH3'],
                'CML3': ['CM3', 'CL3'],
                'C4': ['CE4', 'CH4', 'CM4', 'CL4']
            }
        def_mapping["all"] = [code for codes in def_mapping.values() for code in codes]
        return def_mapping

    def _latest_group_info(self):
        """Return the ``_tgt_group_info`` rows at each customer's latest ``資料開始週期``.

        The latest date is found with ``max`` and matched after casting both
        sides to ``str``; ``資料開始週期`` is a string column in the result.
        """
        info = self._tgt_group_info
        latest_keys = (
            info
            .groupby("客戶編碼")
            .agg({"資料開始週期": "max"})
            .astype({"資料開始週期": str})
            .to_records()
            .tolist()
        )
        return (
            info
            .astype({"資料開始週期": str})
            .set_index(["客戶編碼", "資料開始週期"])
            .loc[latest_keys]
            .reset_index()
        )

    def _alert_add_cust_seg(self, alert):
        """Add ``客戶分群``: each customer's latest segment, applied to all of its alerts."""
        latest = self._latest_group_info()
        segment_of = dict(zip(latest.客戶編碼, latest.客戶分群))
        return alert.assign(客戶分群=alert.客戶編碼.map(segment_of))

    def _top_entities(self, segment_alerts):
        """Return ids of the most-alerted entities of one segment: top 1%, at least 3."""
        counts = (
            segment_alerts
            .groupby([self.entity_level])
            .agg({"資料日期": "count"})
            .rename({"資料日期": "警示數量"}, axis=1)
            .reset_index()
        )
        n_vip = max(int(len(counts) * 0.01), 3)
        return counts.nlargest(n_vip, "警示數量")[self.entity_level].unique()

    def _get_vip_list(self):
        """Return the VIP entity ids (customer or account ids, per ``entity_level``).

        Within each customer segment (``客戶分群``), entities are ranked by
        alert count and the top 1% (at least 3) are kept.

        The alerts counted depend on ``is_Train``:

        - ``is_Train`` true: alerts dated in ``[train_start_date, train_end_date)``.
        - otherwise: alerts dated on or after ``train_start_date``.

        Returns
        -------
        numpy.ndarray
            VIP ids of all segments, concatenated.

        Raises
        ------
        NonDataError
            If no segmented alert falls inside the window.
        """
        dates = self._alert_processed["資料日期"]
        if self.is_Train:
            logger.info(f'data in train, self entity level {self.entity_level}')
            in_window = (dates >= self.train_start_date) & (dates < self.train_end_date)
        else:
            logger.info(f'data in inference, self entity level {self.entity_level}')
            in_window = dates >= self.train_start_date
        windowed = self._alert_processed[in_window]
        if windowed.shape[0] == 0:
            raise NonDataError('no alert inside the VIP window')
        per_segment = windowed.groupby(["客戶分群"]).apply(self._top_entities)
        if len(per_segment) == 0:
            raise NonDataError('no segmented alert inside the VIP window')
        return np.hstack(per_segment)

    def cust_seg_cust_list(self, cust_seg: str):
        """Return (and cache) the customers whose latest segment belongs to ``cust_seg``.

        ``cust_seg`` is a key of ``seg_mapping`` or a single raw segment code.
        """
        if cust_seg not in self.cust_lists:
            cust_seg_list = self.seg_mapping.get(cust_seg, cust_seg)
            if isinstance(cust_seg_list, str):
                cust_seg_list = [cust_seg_list]
            latest = self._latest_group_info()
            self.cust_lists[cust_seg] = (
                latest.loc[latest["客戶分群"].isin(cust_seg_list), "客戶編碼"].values
            )
        return self.cust_lists[cust_seg]

    # ------------------------ sample construction ------------------------

    def _get_data(self, tgt: str):
        """Build the sample frame for ``tgt`` in {"train", "val", "test"}.

        Windows: train = [train_start_date, train_end_date];
        val = [tgt_date - 365d, tgt_date - 183d]; test = tgt_date only.
        Columns: the sample keys, tsfresh features of the C/D/T transactions,
        and the latest risk and customer info as of each sample date.

        Raises
        ------
        ValueError
            For an unknown ``tgt``.
        """
        logger.info(f'days of t v t, {self.train_start_date}, {self.train_end_date}, '
                    f'{self.tgt_date - timedelta(days=365)}, '
                    f'{self.tgt_date - timedelta(days=183)}, {self.tgt_date}')
        windows = {
            "train": (self.train_start_date, self.train_end_date),
            "val": (self.tgt_date - timedelta(days=365), self.tgt_date - timedelta(days=183)),
            "test": (self.tgt_date, self.tgt_date),
        }
        if tgt not in windows:
            raise ValueError("Unknown Target data")
        df = self._create_cust_sunday_cartesian_df(self.vip_list, *windows[tgt])
        logger.info('get data')
        return (
            df
            .assign(**self._create_txn_f(df, self.txn_c, "C"),
                    **self._create_txn_f(df, self.txn_d, "D"),
                    **self._create_txn_f(df, self.txn_t, "T"),
                    **self._add_latest_cust_info(df, self._risk_processed))  # latest risk state
            .pipe(lambda x: x.loc[:, ~x.columns.duplicated()])
            .assign(**self._add_latest_cust_info(df, self._info_processed))  # latest customer info
            .pipe(lambda x: x.loc[:, ~x.columns.duplicated()])
        )

    def _create_cust_sunday_cartesian_df(self, cust, start_date, end_date):
        """Cross every entity in ``cust`` with every Sunday in ``[start_date, end_date]``.

        At ACC level, ``客戶編碼`` is added from ``self.ac``. It is taken from
        the account's bridge record with the latest ``資料開始週期`` strictly
        before the sample date, and is NaN when no such record exists.
        """
        all_days = pd.date_range(start_date, end_date)
        sundays = all_days[all_days.dayofweek == 6]  # Monday=0 ... Sunday=6
        sam_df = pd.DataFrame.from_records(
            product(cust, sundays),
            columns=[self.entity_level, "資料日期"])
        if self.level == "ACC":
            sam_df = (
                sam_df
                .assign(**(
                    sam_df
                    .reset_index()
                    .merge(self.ac, how="left")
                    .pipe(lambda x: (
                        x.loc[
                            x
                            .query("資料日期 > 資料開始週期")
                            .groupby(["帳戶編碼", "資料日期"])
                            ["資料開始週期"]
                            .idxmax()
                        ]
                    ))
                    .set_index('index')
                    .sort_index()
                )[["客戶編碼"]])
            )
        return sam_df

    def _create_txn_f(self, sam_df, txn_df, tran_type='C'):
        """Build tsfresh features of each sample's look-back transactions.

        Parameters
        ----------
        sam_df : pandas.DataFrame
            Sample rows (entity id column and ``資料日期``).
        txn_df : pandas.DataFrame
            Transactions of one type with ``交易日期`` and ``折台幣交易金額``.
        tran_type : str
            Prefix of the amount column and therefore of every feature name.

        Returns
        -------
        pandas.DataFrame
            Indexed by sample index label, with
            ``<tran_type>折台幣交易金額__<feature>`` columns
            (``MinimalFCParameters``). When no sample has transactions, the
            frame is empty but has those columns.
        """
        logger.info(f"tran_type {tran_type}")
        sam_grouped = (
            sam_df
            .assign(**{self.rd_ndays: (sam_df.資料日期 - timedelta(days=self.n_days))})
            .groupby(self.entity_level, sort=False)
        )  # entity_level is the customer (PTY) or account (ACC) id column
        sample_ids = set(sam_grouped.groups)
        sam_df.to_csv(os.path.join(snapshot_folder, 'sam.csv'))
        txn_df.to_csv(os.path.join(snapshot_folder, 'txn.csv'))
        # Two synthetic points (amounts 0 and 1) dated inside every sample's
        # window; presumably so each window passes the len > 1 filter below.
        point_cols = [self.entity_level, '交易日期', '折台幣交易金額']
        sam0 = sam_df.assign(**{'交易日期': (sam_df.資料日期 - timedelta(days=self.n_days - 1)),
                                '折台幣交易金額': 0})[point_cols]
        sam1 = sam_df.assign(**{'交易日期': (sam_df.資料日期 - timedelta(days=self.n_days - 2)),
                                '折台幣交易金額': 1})[point_cols]
        amount_col = f"{tran_type}折台幣交易金額"
        txn_df = (
            pd.concat([txn_df, sam0, sam1])
            .reset_index(drop=True)
            .rename({"折台幣交易金額": amount_col}, axis=1)
        )
        timeseries = (
            txn_df
            .sort_values("交易日期")  # _find_past_txn relies on this order
            .groupby(self.entity_level, sort=False)
            .filter(lambda x: x.name in sample_ids)
            .groupby(self.entity_level, sort=False)
            .apply(lambda x: self._find_past_txn(sam_grouped.get_group(x.name), x))
        )
        logger.info(f'timeseries.index.name  is {timeseries.index.name}')
        timeseries.to_csv(os.path.join(snapshot_folder, 'timeseries.csv'))
        timeseries = timeseries.reset_index(drop=True)
        if len(timeseries):
            timeseries_len_gt1 = (
                timeseries[['id', 'time', amount_col]]
                .dropna()  # some amounts are NaN; drop them instead of imputing
                .groupby("id")
                .filter(lambda x: len(x) > 1)
            )
            X = extract_features(timeseries_len_gt1,
                                 column_id='id',
                                 column_sort='time',
                                 default_fc_parameters=MinimalFCParameters(),
                                 n_jobs=2)
            X.to_csv(os.path.join(snapshot_folder, 'x.csv'))
        else:
            X = pd.DataFrame(
                columns=[f"{amount_col}__{name}" for name in MinimalFCParameters()])
        return X

    def _add_latest_cust_info(self, sam_df, cust_info):
        """Attach, to every sample row, the ``cust_info`` record in force on its date.

        Returns a frame indexed like ``sam_df`` (only customers present in
        both inputs).

        Raises
        ------
        NonDataError
            If no sample customer appears in ``cust_info``.
        """
        sam_grouped = sam_df.groupby("客戶編碼", sort=False)
        sample_customers = set(sam_grouped.groups)
        latest = (
            cust_info
            .sort_values("資料開始週期")  # _find_latest_cust_info relies on this order
            .groupby("客戶編碼", sort=False)
            .filter(lambda x: x.name in sample_customers)
            .groupby("客戶編碼", sort=False)
            .apply(lambda x: self._find_latest_cust_info(sam_grouped.get_group(x.name), x))
        )
        if latest.shape[0] == 0:
            raise NonDataError('output of _add_latest_cust_info is empty')
        return latest.droplevel(0)

    def _find_past_txn(self, cust_sam_df, cust_txn_df):
        """Collect, for every sample of one entity, the transactions in its look-back window.

        A sample's window is ``(window start, 資料日期]``; the start is read
        from the ``rd_ndays`` column. ``cust_txn_df`` must be sorted by
        ``交易日期`` because the bounds are found with ``np.searchsorted``.

        Returns
        -------
        pandas.DataFrame or None
            The matching transaction rows, with two extra columns:

            - ``id``: the sample row's index label.
            - ``time``: the transaction's position on an hourly grid starting
              at the window start.

            ``None`` when no sample has a transaction in its window.
        """
        ts = []
        rds_n = np.searchsorted(
            cust_txn_df["交易日期"].values, cust_sam_df[self.rd_ndays].values, side='right')
        rds = np.searchsorted(
            cust_txn_df["交易日期"].values, cust_sam_df["資料日期"].values, side='right')
        for i, r in enumerate(zip(rds_n, rds)):
            if r[1] - r[0]:
                t = cust_txn_df.iloc[slice(*r)].copy()
                a = cust_sam_df.iloc[i]
                t["id"] = a.name
                date_range = pd.date_range(
                    a[self.rd_ndays], a["資料日期"], freq='H').values
                t["time"] = np.searchsorted(date_range, t["交易日期"].values)
                ts.append(t)
        if ts:
            return pd.concat(ts, axis=0)
        else:
            return None

    def _find_latest_cust_info(self, cust_sam_df, cust_info):
        """For each sample date, pick the last ``cust_info`` row with ``資料開始週期`` <= that date.

        ``cust_info`` must be sorted by ``資料開始週期``. The result is indexed
        like ``cust_sam_df``; samples dated before the first record get NaN.
        """
        idx = np.searchsorted(
            cust_info["資料開始週期"].values, cust_sam_df["資料日期"].values, side='right') - 1

        cust_info = (
            cust_info
            .iloc[idx]
            .set_index(cust_sam_df.index)
        )
        # idx == -1 selected the last row above; blank those samples out.
        cust_info.iloc[np.where(idx == -1)[0]] = np.nan

        return cust_info


# %%
# % % time

# dataset = Dataset(tgt_date=pd.to_datetime("2022-10-02"),
#                   scenario_name='TWNA1401',
#                   group='Person',
#                   level='ACC',
#                   bankL=True)

# %%
# dataset.cust_seg_cust_list('CEH2')

# %%
# assert dataset.num_sam_dist(["year", "week"], on="all").警示數量.sum() == sum(
#     dataset.num_sam_dist(["year", "week"], on="all", cust_seg=seg).警示數量.sum() for seg in dataset.seg_mapping["all"])

# %%
# assert dataset.num_sam_dist(["year", "week"], on="vip").警示數量.sum() == sum(
#     dataset.num_sam_dist(["year", "week"], on="vip", cust_seg=seg).警示數量.sum() for seg in dataset.seg_mapping["all"])

# %%
# 'CEH1': ['CE1', 'CH1'],
# 'CML1': ['CM1', 'CL1'],
# 'CEH2': ['CE2', 'CH2'],
# 'CML2': ['CM2', 'CL2'],
# 'CEH3': ['CE3', 'CH3'],
# 'CML3': ['CM3', 'CL3'],
# 'C4': ['CE4', 'CH4', 'CM4', 'CL4']

# display(
#     dataset
#     ._alert_processed
#     .query("客戶編碼 == '.--.770:'")
#     .sort_values("資料日期")
# )

# print(len(dataset.cust_seg_cust_list('CML1')))

# display(
#     dataset
#     ._alert_processed
#     .query("@dataset.train_start_date <= 資料日期 < @dataset.train_end_date")
#     .query("客戶分群 in ['CM1', 'CL1']")
#     .groupby(["客戶編碼"])
#     .agg({"資料日期": "count"})
#     .rename({"資料日期": "警示數量"}, axis=1)
# )

# display(
#     dataset._tgt_group_info
#     .groupby("客戶編碼")
#     .agg({"資料開始週期": "max"})
#     .astype({"資料開始週期": str})
#     .to_records()
#     .tolist()
# )

# display(dataset._alert_processed.客戶編碼.map((
#     dataset._tgt_group_info
#     .astype({"資料開始週期": str})
#     .set_index(["客戶編碼", "資料開始週期"])
#     .loc[(
#         dataset._tgt_group_info
#         .groupby("客戶編碼")
#         .agg({"資料開始週期": "max"})
#         .astype({"資料開始週期": str})
#         .to_records()
#         .tolist()
#     )]
#     .reset_index()
# ).pipe(lambda x: dict(zip(x.客戶編碼, x.客戶分群)))))

# 以現在 tgt_date 的角度有可能沒有客戶是這個 segmentation
# print(dataset.cust_seg_cust_list('CEH2'))

# for z in ['CEH1', 'CML1', 'CEH2', 'CML2', 'CEH3', 'CML3', 'C4']:
#     print(z, np.intersect1d(dataset.vip_list, dataset.cust_seg_cust_list(z)), len(np.intersect1d(dataset.vip_list, dataset.cust_seg_cust_list(z))))

# %% [markdown]
# # training & inference

# %%

class EmptyTraining(Exception):
    """Raised by ``NumSamV1._fit_reg`` when ``next_week("train")`` has no rows."""


class NumSamV1:
    """
    Forecast alert counts with one XGBoost pipeline per week ahead.

    Model ``k`` (``k`` = 1..``n_weeks``) is fitted on ``_create_label("train", k)``.
    Its features are the calendar columns of the ``k``-th week after each
    sample's ``資料日期``, and its label is that week's ``on="vip"`` alert
    count. ``self.reg`` is a ``LinearRegression`` that maps the summed VIP
    week-1 count of a run date to the all-customer week-1 count (see
    ``_fit_reg``).
    """

    # Number of week-ahead models that are trained and evaluated.
    n_weeks = 27

    def __init__(self, dataset: Dataset, load=False):
        """
        Parameters
        ----------
        dataset : Dataset
            Provides the ``train`` / ``val`` / ``test`` frames and alert counts.
        load : bool, default False
            If True, unpickle every file in ``self.model_path`` (in sorted
            filename order) as the week models. If False, delete every file in
            that folder, train ``n_weeks`` new pipelines and pickle them there.
        """
        self.dataset = dataset
        # Columns removed before fitting / predicting; 警示數量 is the label.
        self.cols_dropped = ["資料日期", "資料開始週期", "警示數量", "relation_duration"]
        package_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(
            package_dir, f'./models/{dataset.scenario_name}/{dataset.group}')
        self.pics_path = os.path.join(
            package_dir, f'./pics/{dataset.scenario_name}/{dataset.group}')
        os.makedirs(self.model_path, exist_ok=True)
        os.makedirs(self.pics_path, exist_ok=True)

        # Per-split inference frames, filled in by `inference`.
        self._train_inference = None
        self._val_inference = None
        self._test_inference = None
        self.cust_seg = None

        if load:
            # Files are saved as "<date>_wNN.pkl", so sorting yields week order,
            # assuming the folder only holds one training run (training clears it).
            self.models = [joblib.load(os.path.join(self.model_path, fn))
                           for fn in sorted(os.listdir(self.model_path))]
        else:
            for fn in os.listdir(self.model_path):
                os.remove(os.path.join(self.model_path, fn))
            self.models = []
            self._train_27_models()
        self.reg = LinearRegression()

    def _create_label(self, tgt: str, tgt_week: int):
        """
        Build the features and label for the ``tgt_week``-th week ahead.

        ``tgt`` names a frame attribute of the dataset ("train", "val" or
        "test"). Each row's reference date is ``資料日期 + (tgt_week-1)*7 + 1``
        days. From it the method adds:

        - the ISO calendar columns (year, week, day), month and quarter;
        - the half-year (1 = Q1-Q2, 2 = Q3-Q4);
        - a 31-day-month flag (bigMon).

        It then left-merges the ``on="vip"`` alert counts on the shared columns
        (presumably 客戶編碼/year/week) and sets missing counts to 0.
        """
        tgt_dates = getattr(self.dataset, tgt).資料日期 + \
            timedelta(days=(tgt_week-1)*7+1)
        return (
            getattr(self.dataset, tgt)
            .assign(**tgt_dates.dt.isocalendar(),
                    month=tgt_dates.dt.month,
                    quarter=tgt_dates.dt.quarter,
                    hfy=tgt_dates.dt.quarter.isin([3, 4])+1,
                    bigMon=tgt_dates.dt.month.isin([1, 3, 5, 7, 8, 10, 12]).astype(int))
            .merge(self.dataset.num_sam_dist(["客戶編碼", "year", "week"], on="vip"), how='left')
            .fillna({"警示數量": 0})
        )

    def _create_pipe(self, X_train):
        """Target-encode object columns (NaN kept), mean-impute, then XGBoost."""
        return make_pipeline(
            TargetEncoder(cols=X_train.columns[X_train.dtypes == 'object'],
                          handle_missing='return_nan'),
            SimpleImputer(missing_values=np.nan, strategy='mean'),
            XGBRegressor()
        )

    def _train_27_models(self):
        """Fit one pipeline per week ahead and pickle it as ``<today>_wNN.pkl``."""
        for i in range(1, self.n_weeks + 1):
            train = self._create_label("train", i)
            x_train, y_train = train.drop(
                self.cols_dropped, axis=1), train.警示數量
            model = self._create_pipe(x_train)
            model.fit(x_train, y_train)
            joblib.dump(model, os.path.join(
                self.model_path, f"{datetime.now().date()}_w{str(i).zfill(2)}.pkl"), compress=3)
            self.models.append(model)

    def _fit_reg(self):
        """
        Fit ``self.reg``: week-1 VIP ground truth -> week-1 all-customer count.

        Raises
        ------
        EmptyTraining
            If the training split yields no run dates.
        """
        train_num_df = self.next_week("train")
        if train_num_df.shape[0] == 0:
            raise EmptyTraining('empty train_num_df')

        self.reg.fit(train_num_df[['week1_gt']].values,
                     train_num_df.week1_total.values)

    def inference(self, cust_seg: str):
        """
        Predict every week ahead for the train/val/test splits of one segment.

        Stores ``self._{train,val,test}_inference``. These are the split frames
        with ``week{k}_gt`` / ``week{k}_pred`` columns, restricted to the
        customers of ``cust_seg``. ``self.reg`` is refitted right after the
        train split.
        """
        logger.info('NumSamV1 inference')
        self.cust_seg = cust_seg
        # Depends only on the segment, so compute it once for all splits.
        cust_list = self.dataset.cust_seg_cust_list(self.cust_seg)
        for tgt in ["train", "val", "test"]:
            y_trues = {}
            y_preds = {}
            for week, model in enumerate(self.models, start=1):
                # Model `week` was trained on _create_label("train", week), so it
                # must see features (and ground truth) of the same week ahead,
                # not week - 1.
                xy = self._create_label(tgt, week)
                x, y = xy.drop(self.cols_dropped, axis=1), xy.警示數量
                y_trues[f"week{week}_gt"] = y.values
                y_preds[f"week{week}_pred"] = model.predict(x)

            setattr(self, f"_{tgt}_inference", (
                getattr(self.dataset, tgt)
                .assign(**y_trues,
                        **y_preds)
                .query("客戶編碼 in @cust_list")
            ))

            if tgt == "train":
                self._fit_reg()

    def _metric_for_27_models(self, tgt: str, metric):
        """Apply ``metric(y_true, y_pred)`` to every week of split ``tgt``."""
        inference_df = getattr(self, f"_{tgt}_inference")
        return np.array([
            metric(inference_df[f"week{i}_gt"].values,
                   inference_df[f"week{i}_pred"].values)
            for i in range(1, self.n_weeks + 1)
        ])

    def mse_for_27_models(self, tgt: str):
        """Per-week mean squared error on split ``tgt`` (array of length n_weeks)."""
        return self._metric_for_27_models(tgt, mean_squared_error)

    def mae_for_27_models(self, tgt: str):
        """Per-week mean absolute error on split ``tgt`` (array of length n_weeks)."""
        return self._metric_for_27_models(tgt, mean_absolute_error)

    def next_week(self, tgt):
        """
        Aggregate the week-1 forecast per run date (``資料日期`` + 1 day).

        Columns:

        - ``week1_pred``: sum of per-customer predictions clipped to [0, 6].
        - ``week1_total_pred``: ``self.reg`` applied to ``week1_pred``.
        - ``week1_gt``: summed VIP ground truth.
        - ``week1_total``: actual all-customer count of the ISO week containing
          the run date.

        All values are floored at 0 and rounded to int. For ``tgt == "test"``
        only the two prediction columns are returned. Side effect: the raw
        inference frame is written to ``snapshot_folder``.
        """
        getattr(self, f"_{tgt}_inference").to_csv(
            os.path.join(snapshot_folder, f"_{tgt}_inference.csv"))
        return (
            getattr(self, f"_{tgt}_inference")
            .pipe(lambda x: x.assign(資料日期=x.資料日期 + timedelta(days=1)))
            .pipe(lambda x: x.assign(**x.filter(regex="week.*_pred").clip(0, 6)))
            .groupby("資料日期")
            .agg({"week1_pred": "sum",
                  "week1_gt": "sum"})
            .pipe(lambda x: (x
                             .assign(week1_total_pred=self._mult_num(x[["week1_pred"]].values),
                                     week1_total=(x
                                                  .index.to_series()
                                                  .dt.isocalendar()
                                                  .merge(self.dataset.num_sam_dist(["year", "week"], on="all", cust_seg=self.cust_seg), how="left")
                                                  .警示數量
                                                  .fillna(0)
                                                  .values))
                             )
                  )
            [["week1_pred", "week1_total_pred", "week1_gt", "week1_total"]]
            .apply(lambda x: np.maximum(x, 0))
            .round()
            .astype(int)
            .pipe(lambda x: x[["week1_pred", "week1_total_pred"]] if tgt == "test" else x)
        )

    def this_period(self, tgt: str, period: str):
        """
        Estimate the alert count of the period containing each ``資料日期``.

        ``period`` is one of "month", "quarter", "hfy" (half-year) or "year".
        Returns ``{period}_pred`` (VIP) and ``{period}_total_pred`` (all
        customers) per date. For non-test splits it also returns the actual
        ``{period}_gt`` / ``{period}_total``.
        """
        assert period in ["month", "quarter", "hfy", "year"]

        def gen_key(x, period):
            # Calendar key columns used to join the actual period counts.
            dic = {"year": x.index.to_series().dt.year}
            if period == "hfy":
                dic.update(
                    {period: x.index.to_series().dt.quarter.isin([3, 4])+1})
            elif period != "year":
                dic.update({period: getattr(x.index.to_series().dt, period)})
            return dic

        # (A recursive `this_period(tgt, "hfy")` call whose result was discarded
        # used to run here for period == "year"; it only cost time.)
        return (
            getattr(self, f"_{tgt}_inference")
            .groupby("資料日期")
            .apply(lambda x: pd.DataFrame.from_dict({f"{period}_pred": [self._this_period_agg(x, period=period, on="vip")],
                                                     f"{period}_total_pred": [self._this_period_agg(x, period=period, on="all")]}))
            .droplevel(1)
            .pipe(lambda x: (x
                             .assign(**{f"{period}_gt": (x
                                                         .assign(**gen_key(x, period))
                                                         .merge(self.dataset.num_sam_dist2(list(set(["year"]+[period])), on="vip", cust_seg=self.cust_seg), how="left")
                                                         .警示數量
                                                         .fillna(0)
                                                         .values)},
                                     **{f"{period}_total": (x
                                                            .assign(**gen_key(x, period))
                                                            .merge(self.dataset.num_sam_dist2(list(set(["year"]+[period])), on="all", cust_seg=self.cust_seg), how="left")
                                                            .警示數量
                                                            .fillna(0)
                                                            .values)})
                             )
                  )
            .round()
            .astype(int)
            .pipe(lambda x: x[[f"{period}_pred", f"{period}_total_pred"]] if tgt == "test" else x)
        )

    def _mult_num(self, x):
        """Apply ``self.reg`` to ``x``; return -1 if the regression is not fitted yet."""
        try:
            return self.reg.predict(x)
        except NotFittedError:
            return -1

    def _this_period_agg(self, tgt_date_pred_df, period, on: str, designated_date=""):
        """
        Actual-to-date plus predicted-remainder alert count for one period.

        Parameters
        ----------
        tgt_date_pred_df : pandas.DataFrame
            One ``資料日期`` group of an inference frame; its ``.name`` is the date.
        period : str
            "month", "quarter", "hfy" or "year".
        on : str
            "vip" sums the clipped VIP predictions. "all" additionally maps the
            sum through ``self.reg`` and floors it at 0.
        designated_date : str, optional
            If given, it replaces the group date. For non-year periods the
            actual count from the period start up to and including that date is
            returned, without any prediction.

        Notes
        -----
        "year" combines half-years. In Q1/Q2 the current half-year estimate is
        scaled by last year's H2/H1 ratio, or last year's H2 is added when H1
        was 0. Otherwise this year's actual H1 is added to the H2 estimate.
        """
        tgt_date = tgt_date_pred_df.name
        if designated_date:
            tgt_date = pd.to_datetime(designated_date)

        if period == "hfy":
            if tgt_date.month >= 7:
                start_period = pd.to_datetime(f"{tgt_date.year}-07-01")
                end_period = pd.to_datetime(f"{tgt_date.year}-12-31")
            else:
                start_period = pd.to_datetime(f"{tgt_date.year}-01-01")
                end_period = pd.to_datetime(f"{tgt_date.year}-06-30")
        else:
            start_period = tgt_date.to_period(
                "M" if period == "month" else "Y" if period == "year" else "Q").start_time
            end_period = tgt_date.to_period(
                "M" if period == "month" else "Y" if period == "year" else "Q").end_time

        sam_sum = 0

        if period == "year":
            sam_sum += self._this_period_agg(tgt_date_pred_df, "hfy", on)
            if tgt_date.quarter in [1, 2]:
                last_year_34 = self._this_period_agg(
                    tgt_date_pred_df, "hfy", on, designated_date=f"{tgt_date.year-1}-12-31")
                last_year_12 = self._this_period_agg(
                    tgt_date_pred_df, "hfy", on, designated_date=f"{tgt_date.year-1}-06-30")
                if last_year_12 > 0:
                    ratio = last_year_34/last_year_12
                    sam_sum *= (ratio+1)
                else:
                    sam_sum += last_year_34
            else:
                sam_sum += self._this_period_agg(
                    tgt_date_pred_df, "hfy", on, designated_date=f"{tgt_date.year}-06-30")
        else:
            if tgt_date > start_period:
                # Actual counts from the period start up to the day before
                # tgt_date (inclusive of designated_date when one is given).
                sam_sum += (
                    pd.date_range(
                        start_period, tgt_date - timedelta(days=0 if designated_date else 1)).isocalendar()
                    .merge(self.dataset.num_sam_dist(["year", "week", "day"], on=on, cust_seg=self.cust_seg), how="left")
                    .警示數量
                    .sum()
                )
                if designated_date:
                    return sam_sum

            days_till_end = pd.date_range(tgt_date, end_period).isocalendar()
            num_week = len(days_till_end[["year", "week"]].drop_duplicates())
            # NOTE(review): `week{num_week}_pred` only exists for num_week <= n_weeks;
            # a half-year starting on a Sunday can touch 28 ISO weeks - confirm.
            # NOTE(review): the partial-week fraction uses the smallest ISO-week
            # bucket of the range, which may be the first rather than the last week.
            if tgt_date <= end_period:
                pred = (
                    tgt_date_pred_df
                    [[f"week{i}_pred" for i in range(1, num_week)]]
                    .apply(lambda x: x.clip(0, 6))
                    .sum().sum()
                ) + (
                    tgt_date_pred_df[f"week{num_week}_pred"]
                    .clip(0, 6)
                    .sum()
                ) * (days_till_end.groupby(["year", "week"]).size().min() / 7.)

                if on == "all":
                    pred = np.maximum(self.reg.predict(
                        np.array([[pred]]))[0], 0)

                sam_sum += pred

        return sam_sum

    def save_png_ret_str(self):
        """
        Plot the train+val ``next_week`` trend, save it and return it base64-encoded.

        The PNG is written to ``<pics_path>/pred_trend.png``. The return value
        is the base64 text of that file.
        """
        png_path = os.path.join(self.pics_path, "pred_trend.png")
        ax = (
            pd.concat([self.next_week("train"),
                       self.next_week("val")])
            .rename_axis("run_date", axis=0)
            .plot(marker=".")
        )
        fig = ax.get_figure()
        try:
            fig.savefig(png_path)
        finally:
            # Without closing, every call leaves an open pyplot figure behind.
            plt.close(fig)

        with open(png_path, "rb") as f:
            png_encoded = base64.b64encode(f.read())

        return png_encoded.decode()


# %%
# model = NumSamV1(dataset, load=False)

# %%
# model.inference(cust_seg="all")

# %%
# pd.concat([model.next_week("train"), model.next_week("val")]
#           ).rename_axis("run_date", axis=0).plot(marker=".")
# plt.show()

# %%
# plt.plot(model.mse_for_27_models("train"), label="train")
# plt.plot(model.mse_for_27_models("val"), label="val")
# plt.xlabel("nth model")
# plt.ylabel("M S E")
# plt.legend()
# plt.show()

# %%
# plt.plot(model.mae_for_27_models("train"), label="train")
# plt.plot(model.mae_for_27_models("val"), label="val")
# plt.xlabel("nth model")
# plt.ylabel("M A E")
# plt.legend()

# plt.show()

# %%
# model.this_period("test", "hfy")

# %%

# Single-letter segment-group code -> Dataset group name.
group_dir = dict(P='Person', C='Corp')
# %%
# tgt_date must be a Sunday; any other weekday is rejected.

# Scenario code -> aggregation level passed to Dataset.
level_dir = dict(TWNA1A01="PTY", TWNAB101="PTY", TWNA1401="ACC")


def training(date=None, seg_group="P", snro_cd='TWNA1A01', bankL=True):
    """
    Retrain the week-ahead models of one scenario for one customer group.

    Parameters
    ----------
    date : datetime-like, optional
        Target date; it must be a Sunday. Defaults to ``datetime.now()``
        evaluated at call time (not frozen at import time).
    seg_group : str
        "P" (Person) or "C" (Corp), see ``group_dir``.
    snro_cd : str
        Scenario code, see ``level_dir``.
    bankL : bool
        Passed through to ``Dataset``.

    Returns
    -------
    dict
        Always has ``tgt_date``, ``seg_group``, ``status_code`` and ``message``.
        On success (status_code 1) it also has:

        - ``score``: mean validation MSE/MAE over the week models, rounded to 2.
        - ``pic``: base64 PNG text.
        - ``individual_model_score``: per-week ``{'MAE', 'MSE'}``.

        If ``Dataset`` raises ``NonDataError``, status_code is 0 and the message
        holds the error plus traceback.

    Raises
    ------
    AssertionError
        If ``date`` is not a Sunday.
    KeyError
        If ``snro_cd`` or ``seg_group`` is unknown.
    """
    if date is None:
        date = datetime.now()
    logger.info(f'-----------------start training------------------  \
                \nscenario name: {snro_cd}   \
                \ndate: {date}\
                \nseg: {seg_group}\
                \n-------------------------------------------------')
    level = level_dir[snro_cd]
    if date.isocalendar()[2] != 7:
        # Explicit raise (same exception type as the old assert) so the check
        # is not stripped under `python -O`.
        raise AssertionError(f'tgt_date must be a Sunday, got {date}')
    group = group_dir[seg_group]

    result = {'tgt_date': date, 'seg_group': group}
    # Only two groups need retraining: natural persons (Person) or legal entities (Corp).
    try:
        dataset = Dataset(date, snro_cd, group, level,
                          bankL=bankL, is_Train=True)
    except NonDataError as err:
        # Keep tgt_date / seg_group so callers can still identify the run.
        result['status_code'] = 0
        result['message'] = str(err) + '\n' + traceback.format_exc()
        return result

    model = NumSamV1(dataset, load=False)
    model.inference(cust_seg="all")

    # Compute the per-week metrics once instead of once per week.
    val_mse = model.mse_for_27_models("val")
    val_mae = model.mae_for_27_models("val")
    result['score'] = {
        'MSE': round(np.average(val_mse), 2),
        'MAE': round(np.average(val_mae), 2)
    }
    result['pic'] = model.save_png_ret_str()
    result['individual_model_score'] = {
        week: {'MAE': mae, 'MSE': mse}
        for week, (mae, mse) in enumerate(zip(val_mae, val_mse), start=1)
    }
    result['status_code'] = 1
    result['message'] = 'success'

    return result


def _empty_seg_result(status_code, message):
    """Per-segment result with every forecast set to 0."""
    return {
        'next_week_predict': 0,
        'month_predict': 0,
        'season_predict': 0,
        'half_year_predict': 0,
        'year_predict': 0,
        'status_code': status_code,
        'message': message,
    }


def inference(date=None, snro_cd='TWNA1A01', group="all", bankL=True):
    """
    Run the saved week models for every customer segment of the given group(s).

    Parameters
    ----------
    date : datetime-like, optional
        Target date; it must be a Sunday. Defaults to ``datetime.now()``
        evaluated at call time.
    snro_cd : str
        Scenario code, see ``level_dir``.
    group : str
        "Person", "Corp" or "all" (both).
    bankL : bool
        Passed through to ``Dataset``.

    Returns
    -------
    dict
        ``{'tgt_date': date, 'seg_group': {seg: {...}}}``. Each segment entry
        holds next-week / month / season / half-year / year total predictions,
        a ``status_code`` and a ``message``. The predictions are all 0 when the
        segment has no data (status_code 0 when ``Dataset`` raised
        ``NonDataError``, otherwise 1).
    """
    if date is None:
        date = datetime.now()
    logger.info(f'----------------start inference------------------  \
                \n scenario name: {snro_cd}   \
                \n date: {date}\
                \n seg: {group}\
                \n-------------------------------------------------')
    if date.isocalendar()[2] != 7:
        # Explicit raise (same type as the old assert) survives `python -O`.
        raise AssertionError(f'tgt_date must be a Sunday, got {date}')
    level = level_dir[snro_cd]
    segs = {
        "Person": ['PML', 'PEH'],
        "Corp": ['CML1', 'CEH1', 'CML2', 'CEH2', 'CML3', 'CEH3', 'C4']
    }

    group_list = ['Person', 'Corp'] if group == 'all' else [group]

    result = {'tgt_date': date, 'seg_group': {}}

    for grp in group_list:
        logger.info(f'load {grp}')
        try:
            dataset = Dataset(date, snro_cd, group=grp,
                              level=level, bankL=bankL, is_Train=False)
        except NonDataError as err:
            logger.info('NonDataError')
            message = str(err) + '\n' + traceback.format_exc()
            for seg in segs[grp]:
                result['seg_group'][seg] = _empty_seg_result(0, message)
            continue

        for seg in segs[grp]:
            logger.info(f' run {seg}')
            n_cust = len(dataset.cust_seg_cust_list(seg))
            if n_cust == 0:
                logger.info('len(dataset.cust_seg_cust_list(seg)) == 0')
                result['seg_group'][seg] = _empty_seg_result(
                    1, "len(dataset.cust_seg_cust_list(seg)) == 0! ")
                logger.info(f"no customer in {seg} at tgt_date!")
                continue

            logger.info(f'len(dataset.cust_seg_cust_list(seg)) == {n_cust}')
            model = NumSamV1(dataset, load=True)
            try:
                model.inference(seg)
                next_week_df = model.next_week("test")
                if next_week_df.week1_total_pred.shape[0] == 0:
                    raise Nonnextweek('week1_total_pred.shape is empty')
            except EmptyTraining as err:
                logger.info(f'{err}')
                result['seg_group'][seg] = _empty_seg_result(
                    1, f"EmptyTraining data! {err}")
                logger.info(f"no customer in {seg} at tgt_date!")
                continue
            except Nonnextweek as err:
                logger.info(f'{err}')
                result['seg_group'][seg] = _empty_seg_result(
                    1, f"no customer in {seg} at tgt_date!{err}")
                logger.info(f"no customer in {seg} at tgt_date!")
                continue

            # The frames are indexed by date, so take the first row by position
            # (integer keys on a DatetimeIndex Series are a deprecated fallback).
            result['seg_group'][seg] = {
                'next_week_predict': next_week_df.week1_total_pred.iloc[0],
                'month_predict': model.this_period("test", "month").month_total_pred.iloc[0],
                'season_predict': model.this_period("test", "quarter").quarter_total_pred.iloc[0],
                'half_year_predict': model.this_period("test", "hfy").hfy_total_pred.iloc[0],
                'year_predict': model.this_period("test", "year").year_total_pred.iloc[0],
                'status_code': 1,
                'message': 'success'
            }
    return result


# %%


class P(pprint.PrettyPrinter):
    """PrettyPrinter that truncates strings longer than 20 characters to 20 + '...'."""

    def _format(self, object, *args, **kwargs):
        # Only strings are shortened; every other object is formatted as usual.
        if isinstance(object, str) and len(object) > 20:
            object = f"{object[:20]}..."
        return pprint.PrettyPrinter._format(self, object, *args, **kwargs)


# %%
if __name__ == "__main__":
    # Manual run: for each scenario, train the P and C models, then run
    # inference, writing JSON results and a summary log into `snapshot_folder`.
    import json

    with open(os.path.join(snapshot_folder, 'final_result.txt'), 'w', encoding='utf-8') as g:

        the_date = "2022-10-16"
        for _ in range(1):  # number of consecutive weeks to run
            print(the_date)
            for scenario in ['TWNA1A01', 'TWNAB101', 'TWNA1401']:

                for seg_group in ['P', 'C']:
                    g.write(f'\n{scenario} {seg_group} start training  the_date: {the_date}')
                    a_time = datetime.now()
                    training_out = training(date=pd.to_datetime(the_date),
                                            snro_cd=scenario,
                                            seg_group=seg_group,
                                            bankL=True)
                    # Timestamps are not JSON-serialisable; a failed run may lack the key.
                    if 'tgt_date' in training_out:
                        training_out['tgt_date'] = str(training_out['tgt_date'])
                    with open(os.path.join(snapshot_folder, f'{the_date}_{scenario}_{seg_group}_train_result.json'), 'w', encoding='utf-8') as h:
                        json.dump(training_out, h)
                    g.write(
                        f'\n{scenario} {seg_group} training finished. use time {(datetime.now()-a_time).total_seconds()}')
                    score = training_out.get('score')
                    if score is None:
                        # Training failed (e.g. NonDataError): log why instead of crashing.
                        g.write(f"\n status_code: {training_out.get('status_code')}"
                                f" message: {training_out.get('message')}")
                    else:
                        g.write(f"\n MSE: {score['MSE']}")
                        g.write(f"\n MAE: {score['MAE']}")

                g.write(f'\n{scenario} start inference  the_date: {the_date}')
                a_time = datetime.now()
                inference_out = inference(date=pd.to_datetime(the_date),
                                          snro_cd=scenario,
                                          group='all',
                                          bankL=True)
                inference_out['tgt_date'] = str(inference_out['tgt_date'])
                # Stringify every per-segment value; numpy integers are not JSON-serialisable.
                inference_out['seg_group'] = {
                    seg: {key: str(value) for key, value in seg_result.items()}
                    for seg, seg_result in inference_out['seg_group'].items()
                }
                with open(os.path.join(snapshot_folder, f'{the_date}_{scenario}_inference_result.json'), 'w', encoding='utf-8') as h:
                    json.dump(inference_out, h)

                g.write(
                    f'\n{scenario} inference finished. use time {(datetime.now()-a_time).total_seconds()}.\n the_date: {the_date}\n {inference_out}')

            the_date = (datetime.strptime(the_date, '%Y-%m-%d').date()
                        + timedelta(days=7)).strftime('%Y-%m-%d')
