In [None]:
%run analyses/imports.py

### 손실 추이 분석
### TODO
- `highlight_<...>_titles`를 N개 받아 각자 다른 색상으로 처리할 수 있도록 수정 (up to v0.1.0)
- `plot_losses_overlap`이 `loss_dict`를 받아 처리할 수 있도록 수정 (up to v0.1.0)
### 사용 방법
- `TSB-AD/results/loss` 아래에 모델 이름의 디렉터리별로, 시계열마다 학습(`*_train.npy`)·검증(`*_valid.npy`) 손실 추이가 .npy 포맷으로 저장되어 있음
- ⭐ `your_root_dir` 경로 지정 필수 (imports.py에서 수행)
- `keywords_to_include` 리스트에 불러올 모델 디렉터리를 가리키는 키워드 저장 (`keywords_to_exclude`는 현재 `PathLoader`에 전달되지 않으므로 효과가 없음)
- `file_paths`에 해당 디렉터리들의 경로 저장

##### 참고용 디렉터리 구조
```
loss/
    |- 025_ParallelSNN_receptive_None/ (모델명)
        |- 001_Genesis_id_1_Sensor_tr_4055_1st_15538_train.npy (학습 손실 추이, ndarray)
        |- 001_Genesis_..._valid.npy (검증 손실 추이, ndarray)
        |- ...
    |- ...
```

In [None]:
# Root folder holding one loss sub-directory per model run.
src_dir_path = '{}/TSB-AD/results/loss'.format(your_root_dir)
# Keywords identifying the model directories to load.
keywords_to_include = ['025', '026']
keywords_to_exclude = list()

In [None]:
# PathLoader is defined outside this notebook (presumably in analyses/imports.py);
# it appears to select the sub-directories of `src_dir_path` matching
# `keywords_to_include` — confirm against its definition.
# NOTE(review): `keywords_to_exclude` is defined above but never passed here,
# so it currently has no effect.
path_loader = PathLoader(src_dir_path, keywords_to_include)

In [None]:
# Collect the matched model directories, then report how many there are and list them.
file_paths = path_loader.get_file_paths()
num_paths = len(file_paths)
print(num_paths)
for matched_path in file_paths:
    print(matched_path)

In [None]:
def load_train_valid_loss(target_dir_path):
    """Load the per-time-series train/valid loss curves stored in one model directory.

    Loss files are ``.npy`` arrays whose base name ends in ``train`` or
    ``valid``. The first two ``_``-separated fields of the name
    (e.g. ``001_Genesis``) form the key that pairs a train file with its
    valid file. Any other file in the directory is ignored.

    Args:
        target_dir_path: Directory containing the ``.npy`` loss files of one model.

    Returns:
        dict: ``key -> (train_loss, valid_loss)`` sorted by key. Either element
        is ``None`` when the corresponding file is missing.
    """
    import collections

    temp_dict = collections.defaultdict(dict)

    for file_name in os.listdir(target_dir_path):
        base, ext = os.path.splitext(file_name)
        # Skip anything that is not a saved loss array (e.g. `.DS_Store`).
        # Such files used to crash the key parsing below with an IndexError.
        if ext.lower() != '.npy':
            continue
        if base.endswith('train'):
            split = 'train'
        elif base.endswith('valid'):
            split = 'valid'
        else:
            continue
        # Key = first two '_' fields, e.g. '001_Genesis'.
        key = '_'.join(base.split('_')[:2])
        temp_dict[key][split] = np.load(os.path.join(target_dir_path, file_name))

    # Keys are unique strings, so sorting the items orders them by key.
    return {
        key: (values.get('train'), values.get('valid'))
        for key, values in sorted(temp_dict.items())
    }

In [None]:
def plot_all_losses(losses, title, highlight_red_titles=None, highlight_blue_titles=None):
    """Plot the train/valid loss curves of every time series of a single model.

    Each entry of ``losses`` gets its own subplot: train loss in blue, valid
    loss in orange. Subplot titles can be highlighted by the three-digit
    time-series ID, which is the part of the key before the first ``_``:
    - red for IDs in ``highlight_red_titles``;
    - blue for IDs in ``highlight_blue_titles``;
    - purple for IDs in both lists;
    - black otherwise.

    Args:
        losses: Mapping ``ts_name -> (train_loss, valid_loss)`` as returned by
            ``load_train_valid_loss``. A ``None`` curve is skipped.
        title: Overall figure title.
        highlight_red_titles: Optional collection of time-series IDs to color red.
        highlight_blue_titles: Optional collection of time-series IDs to color blue.
    """
    import math

    num_ts = len(losses)
    # Keep the hand-tuned layouts for the known dataset sizes (2, 29, 180
    # series). Any other count gets a generic grid of up to 6 columns
    # instead of failing with an unbound `axs`.
    layouts = {2: (1, 2, (12, 6)), 29: (5, 6, (18, 12)), 180: (30, 6, (18, 54))}
    if num_ts in layouts:
        nrows, ncols, figsize = layouts[num_ts]
    else:
        ncols = max(1, min(6, num_ts))
        nrows = max(1, math.ceil(num_ts / ncols))
        figsize = (3 * ncols, 2.4 * nrows)
    # squeeze=False always yields a 2-D array (even for 1x1), so flattening
    # gives an indexable 1-D sequence of axes.
    _, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    highlighting = highlight_red_titles is not None or highlight_blue_titles is not None
    red_ids = highlight_red_titles if highlight_red_titles is not None else ()
    blue_ids = highlight_blue_titles if highlight_blue_titles is not None else ()
    # (in_red, in_blue) -> title color.
    title_colors = {
        (True, True): 'purple',
        (True, False): 'red',
        (False, True): 'blue',
        (False, False): 'black',
    }

    for idx, (ts_name, v) in enumerate(losses.items()):
        ax = axs[idx]
        train_loss, valid_loss = v[0], v[1]
        # A curve is None when its .npy file was missing; skip it rather than crash.
        if train_loss is not None:
            ax.plot(train_loss, label='Train Loss', color='blue')
        if valid_loss is not None:
            ax.plot(valid_loss, label='Valid Loss', color='orange')

        if highlighting:
            ts_id = ts_name.split('_')[0]
            ax.set_title(ts_name, color=title_colors[(ts_id in red_ids, ts_id in blue_ids)])
        else:
            ax.set_title(ts_name)

    # Hide the grid cells that have no time series.
    for ax in axs[num_ts:]:
        ax.axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.975])
    plt.show()

In [None]:
def plot_losses_overlap(losses_list, title=None, highlight_red_titles=None, highlight_blue_titles=None):
    """Overlay the loss curves of several models, one subplot per time series.

    Each model gets its own hue:
    - the validation curve (dash-dot) uses a saturation-boosted version of
      a ``tab10`` palette color;
    - the training curve (solid) uses a lighter tint of the same color.

    A figure legend maps each hue to its model name. Subplot titles are
    highlighted by the three-digit time-series ID, which is the part of the
    key before the first ``_``:
    - red for IDs in ``highlight_red_titles``;
    - blue for IDs in ``highlight_blue_titles``;
    - purple for IDs in both lists;
    - black otherwise.

    Args:
        losses_list: List of ``(model_name, losses)`` pairs, or a mapping
            ``model_name -> losses`` such as ``loss_dict``. Each ``losses`` maps
            ``ts_name -> (train_loss, valid_loss)`` (see
            ``load_train_valid_loss``). A ``None`` curve is skipped. The
            ``idx``-th series of every model goes into subplot ``idx``, so
            all models are assumed to list the same series in the same order.
        title: Optional overall figure title.
        highlight_red_titles: Optional collection of time-series IDs to color red.
        highlight_blue_titles: Optional collection of time-series IDs to color blue.

    Raises:
        ValueError: If no model is given.
    """
    import math

    # Accept `loss_dict` (model name -> losses) directly, as well as a list
    # of (name, losses) pairs.
    if isinstance(losses_list, dict):
        losses_list = list(losses_list.items())
    if not losses_list:
        raise ValueError('losses_list must contain at least one (name, losses) pair')

    # Size the grid for the model with the most series so no subplot index overflows.
    num_ts = max(len(losses) for _, losses in losses_list)
    # Keep the hand-tuned layouts for the known dataset sizes. Any other count
    # gets a generic grid of up to 6 columns instead of an unbound `axs`.
    layouts = {2: (1, 2, (12, 6)), 29: (5, 6, (18, 12)), 180: (30, 6, (18, 54))}
    if num_ts in layouts:
        nrows, ncols, figsize = layouts[num_ts]
    else:
        ncols = max(1, min(6, num_ts))
        nrows = max(1, math.ceil(num_ts / ncols))
        figsize = (3 * ncols, 2.4 * nrows)
    # squeeze=False always yields a 2-D array (even for 1x1), so flattening
    # gives an indexable 1-D sequence of axes.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    highlighting = highlight_red_titles is not None or highlight_blue_titles is not None
    red_ids = highlight_red_titles if highlight_red_titles is not None else ()
    blue_ids = highlight_blue_titles if highlight_blue_titles is not None else ()
    # (in_red, in_blue) -> title color.
    title_colors = {
        (True, True): 'purple',
        (True, False): 'red',
        (False, True): 'blue',
        (False, False): 'black',
    }

    num_model = len(losses_list)
    base_colors = sns.color_palette("tab10", n_colors=num_model)
    legend_handles = []
    for model_idx, (name, losses) in enumerate(losses_list):
        r, g, b = base_colors[model_idx]
        h, l, s = colorsys.rgb_to_hls(r, g, b)
        # Boost the saturation, capped at 1.0 (the HLS maximum).
        valid_col = colorsys.hls_to_rgb(h, l, min(s * 6.0, 1.0))
        # The train curve is blended 30% toward white for a lighter tint.
        train_col = tuple(np.array(valid_col) * 0.7 + 0.3)

        for idx, (ts_name, v) in enumerate(losses.items()):
            ax = axs[idx]
            train_loss, valid_loss = v[0], v[1]
            # A curve is None when its .npy file was missing; skip it rather than crash.
            if train_loss is not None:
                ax.plot(train_loss, label='Train Loss', color=train_col, alpha=0.5)
            if valid_loss is not None:
                ax.plot(valid_loss, label='Valid Loss', color=valid_col, alpha=0.5, linestyle='-.')

            if highlighting:
                ts_id = ts_name.split('_')[0]
                ax.set_title(ts_name, color=title_colors[(ts_id in red_ids, ts_id in blue_ids)])
            else:
                ax.set_title(ts_name)

        # Use the color actually drawn, not the raw palette color, so the
        # legend swatch matches the curves.
        legend_handles.append(mpatches.Patch(color=valid_col, label=name))

    # Hide the grid cells that have no time series.
    for ax in axs[num_ts:]:
        ax.axis('off')

    fig.legend(handles=legend_handles, loc='upper left', ncol=num_model, bbox_to_anchor=(0.025, 0.98))

    plt.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()

In [None]:
# Map each model directory name to its {series_key: (train, valid)} losses.
file_names = path_loader.get_file_names()
loss_dict = {
    model_name: load_train_valid_loss(model_path)
    for model_name, model_path in zip(file_names, file_paths)
}

In [None]:
# List the model directories that were loaded.
for model_name in loss_dict:
    print(model_name)

#### 시각화
- `highlight_red_titles` / `highlight_blue_titles` 리스트에 시계열의 세 자리 ID를 문자열로 입력하면 해당 손실 추이의 소제목이 각각 빨간색/파란색으로 강조됨 (두 리스트에 모두 있으면 보라색)
- `plot_all_losses()` 함수는 한 모델의 모든 손실 추이를 시각화
- `plot_losses_overlap()` 함수는 여러 모델의 손실 추이를 시계열별 서브플롯에 겹쳐서 시각화

In [None]:
# Three-digit time-series IDs whose subplot titles get highlighted
# (red / blue; an ID present in both lists is drawn purple by the plot functions).
highlight_red_titles = '065 073 078 115 130 144 173'.split()
highlight_blue_titles = (
    '032 034 035 037 038 039 041 043 044 '
    '045 046 047 048 051 052 053 054 055'
).split()

In [None]:
# All loss curves of model 025, with the selected series IDs highlighted.
plot_all_losses(
    loss_dict['025_ParallelSNN_receptive_None'],
    title='ParallelSNN with Receptive Encoder',
    highlight_red_titles=highlight_red_titles,
    highlight_blue_titles=highlight_blue_titles,
)

In [None]:
# NOTE(review): `losses_list` is never defined in this notebook. plot_losses_overlap
# iterates it as (model_name, losses) pairs, so something like
# `losses_list = list(loss_dict.items())` is needed before enabling this call.
#plot_losses_overlap(losses_list, title='Comparison by overlapping', highlight_red_titles=highlight_red_titles, highlight_blue_titles=highlight_blue_titles)