In [5]:
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact, widgets, FloatSlider, IntSlider, HBox, VBox, Layout, Output
from IPython.display import display
import io
import base64
from IPython.display import HTML, display
import os

# Float as written in the results file: optional sign, digits with an optional
# decimal point (or a leading-dot fraction such as ".5"), optional exponent
# such as "1e-05". The previous "[\d\.]+" pattern silently truncated
# exponent values (e.g. "0.5e-3" was read as 0.5).
_FLOAT_PATTERN = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# One result line, e.g. "./emb_a10b20.npy, TC = 0.512, TD = 0.873".
# Note the leading "." is an unescaped wildcard, so any single character
# followed by "/emb_a..." matches. Compiled once rather than per line.
_RESULT_LINE_RE = re.compile(
    r'./emb_a(\d+)b(\d+)\.npy, TC = (' + _FLOAT_PATTERN + r'), TD = (' + _FLOAT_PATTERN + r')'
)

# Column order of the DataFrame returned by read_data.
_RESULT_COLUMNS = ['a', 'b', 'TC', 'TD']


def read_data(file_path='text-result-st+st-0430.txt'):
    """Parse a results text file into a DataFrame, one row per matching line.

    The file is read as UTF-8. Lines that do not match ``_RESULT_LINE_RE``
    (including blank lines) are skipped.

    Parameters
    ----------
    file_path : str
        Path of the results file.

    Returns
    -------
    pandas.DataFrame
        Columns ``a`` and ``b`` (int64, parsed from the ``emb_a<a>b<b>.npy``
        name) and ``TC`` and ``TD`` (float64). The columns exist even when no
        line matched, so downstream ``groupby(['a', 'b'])`` does not raise
        KeyError on an empty result.
    """
    rows = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            match = _RESULT_LINE_RE.search(line.strip())
            if match is None:
                continue
            rows.append({
                'a': int(match.group(1)),
                'b': int(match.group(2)),
                'TC': float(match.group(3)),
                'TD': float(match.group(4)),
            })
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS).astype(
        {'a': 'int64', 'b': 'int64', 'TC': 'float64', 'TD': 'float64'}
    )

# Load every result line from the default results file. Where the same (a, b)
# combination appears more than once, keep the mean of its TC and TD values.
df = (
    read_data()
    .groupby(['a', 'b'])
    .mean()
    .reset_index()
)

# Sum of the two metrics, plotted alongside TC and TD everywhere below.
df['TC+TD'] = df['TC'].add(df['TD'])

# Reshape long-format (a, b, metric) rows into a 2-D grid for heatmaps.
def create_grid_data(df, metric):
    """Pivot ``metric`` into an (a x b) grid suitable for ``sns.heatmap``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain columns ``'a'``, ``'b'`` and ``metric``.
    metric : str
        Column whose values fill the grid.

    Returns
    -------
    grid : numpy.ndarray
        Float array of shape ``(len(a_values), len(b_values))``. Cell ``[i, j]``
        holds ``metric`` for ``(a_values[i], b_values[j])``, or NaN when that
        combination has no row. If several rows share an (a, b) pair, the first
        one in row order is used.
    a_values : list
        Sorted unique values of ``'a'`` (grid rows).
    b_values : list
        Sorted unique values of ``'b'`` (grid columns).
    """
    a_values = sorted(df['a'].unique())
    b_values = sorted(df['b'].unique())

    if df.empty:
        return np.zeros((0, 0)), a_values, b_values

    # One vectorised pivot replaces the former per-cell boolean-mask scan of
    # the whole frame, which cost O(|a| * |b| * len(df)).
    first_rows = df.drop_duplicates(subset=['a', 'b'], keep='first')
    table = first_rows.pivot(index='a', columns='b', values=metric)
    # reindex places missing (a, b) combinations as NaN, in sorted order.
    grid = table.reindex(index=a_values, columns=b_values).to_numpy(dtype=float)
    return grid, a_values, b_values

# Shared state used by every plotting callback below.
current_fig = None  # Figure most recently drawn by a plot function; None when nothing was drawn
output_area = Output()  # widget that all plots and tables are rendered into

# Save the current figure with a box summarising where each metric peaks.
def save_figure_with_stats(filename="figure_with_stats.png", filtered_data=None):
    """Save ``current_fig`` with a temporary statistics box drawn at its bottom.

    The box lists the maximum TC, TD and TC+TD and the (a, b) at which each
    occurs. It is added to the figure, the figure is written with
    ``savefig(dpi=300, bbox_inches='tight')``, and the box is then removed and
    the layout recomputed. This cleanup runs even if saving fails.

    Parameters
    ----------
    filename : str
        Output path passed to ``Figure.savefig``.
    filtered_data : pandas.DataFrame or None
        Rows to compute the maxima from. When ``None``, the module-level ``df``
        is used. Needs columns 'a', 'b', 'TC', 'TD' and 'TC+TD'.

    Returns
    -------
    str or None
        ``filename`` on success; ``None`` when there is no figure or no data.
    """
    global current_fig

    if current_fig is None:
        print("No figure available to save")
        return

    data_to_use = filtered_data if filtered_data is not None else df
    if data_to_use.empty:
        print("No data available to compute statistics")
        return

    def max_line(column):
        row = data_to_use.loc[data_to_use[column].idxmax()]
        # A row taken with .loc from a frame mixing int and float columns is
        # upcast to float. Cast a and b back to int so they print as "a=5",
        # not "a=5.0".
        return f"Max {column}: {row[column]:.4f}, at a={int(row['a'])}, b={int(row['b'])}"

    stats_text = "\n".join([
        "Maximum Metrics Summary (from current filtered data):",
        max_line('TC'),
        max_line('TD'),
        max_line('TC+TD'),
    ])

    # Place the text in figure coordinates, centred at the bottom edge.
    text_obj = current_fig.text(0.5, 0.01, stats_text, ha='center',
                                bbox={'facecolor': 'yellow', 'alpha': 0.7, 'pad': 10, 'boxstyle': 'round'},
                                fontsize=12, transform=current_fig.transFigure)
    try:
        # Reserve the bottom 15% of the figure for the statistics box.
        current_fig.tight_layout(rect=[0, 0.15, 1, 1])
        current_fig.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Always undo the temporary annotation so the on-screen figure is left
        # as it was, even if savefig raised (bad path, permissions, ...).
        text_obj.remove()
        current_fig.tight_layout()

    print(f"Figure and statistics saved to: {filename}")
    return filename

# Interactive heatmap view of TC, TD and TC+TD over the (a, b) grid.
def plot_heatmaps(tc_threshold=0.0, td_threshold=0.0, tc_upper=1.0, td_upper=1.0, sum_threshold=0.0, sum_upper=2.0, a_min=0, a_max=100, b_min=0, b_max=100):
    """Draw stacked TC, TD and TC+TD heatmaps for rows of ``df`` within the bounds.

    All bounds are inclusive. A row is kept when
    ``tc_threshold <= TC <= tc_upper``, ``td_threshold <= TD <= td_upper``,
    ``sum_threshold <= TC+TD <= sum_upper``, ``a_min <= a <= a_max`` and
    ``b_min <= b <= b_max``.

    Output goes to ``output_area``: the three heatmaps, the location of each
    metric's maximum, and a button that saves the figure via
    ``save_figure_with_stats``. ``current_fig`` is set to the new figure, or to
    ``None`` when no row passes the filters.
    """
    global current_fig

    # Series.between is inclusive on both ends, matching ">= lo & <= hi".
    mask = (
        df['TC'].between(tc_threshold, tc_upper)
        & df['TD'].between(td_threshold, td_upper)
        & df['TC+TD'].between(sum_threshold, sum_upper)
        & df['a'].between(a_min, a_max)
        & df['b'].between(b_min, b_max)
    )
    filtered_df = df[mask]

    # (column, colormap, title) for each stacked subplot, top to bottom.
    panels = (
        ('TC', 'YlGnBu', 'TC Heatmap'),
        ('TD', 'YlOrRd', 'TD Heatmap'),
        ('TC+TD', 'viridis', 'TC+TD Sum Heatmap'),
    )

    with output_area:
        output_area.clear_output(wait=True)

        if filtered_df.empty:
            print("No data matches the criteria")
            current_fig = None
            return

        fig, axes = plt.subplots(len(panels), 1, figsize=(12, 24))
        current_fig = fig

        for ax, (metric, cmap, title) in zip(axes, panels):
            grid, a_values, b_values = create_grid_data(filtered_df, metric)
            sns.heatmap(grid,
                        annot=True,
                        fmt=".3f",
                        cmap=cmap,
                        ax=ax,
                        xticklabels=b_values,
                        yticklabels=a_values,
                        annot_kws={"size": 10})  # larger annotation text
            ax.set_title(title, fontsize=20)
            ax.set_xlabel('b value', fontsize=16)
            ax.set_ylabel('a value', fontsize=16)
            ax.tick_params(labelsize=12)  # larger tick labels

        plt.tight_layout(pad=3.0)  # extra spacing between the stacked subplots
        plt.show()

        # Report where each metric peaks. A row taken with .loc from a frame
        # mixing int and float columns comes back as a float Series, so a and b
        # are cast to int to print "a=5" instead of "a=5.0".
        for metric in ('TC', 'TD', 'TC+TD'):
            row = filtered_df.loc[filtered_df[metric].idxmax()]
            print(f"Max {metric}: {row[metric]:.4f}, at a={int(row['a'])}, b={int(row['b'])}")

        save_btn = widgets.Button(description="Save Figure with Stats")
        save_filename = f"heatmaps_with_stats_tc{tc_threshold:.2f}-{tc_upper:.2f}_td{td_threshold:.2f}-{td_upper:.2f}.png"
        save_btn.on_click(lambda _btn: save_figure_with_stats(save_filename, filtered_df))
        display(save_btn)

# Line plot of one metric against a or b while the other parameter is held fixed.
def plot_line(fixed_param='a', fixed_value=10, metric='TC'):
    """Plot ``metric`` against the free parameter while the other is fixed.

    Parameters
    ----------
    fixed_param : str
        ``'a'`` holds a fixed and sweeps b. Any other value holds b fixed and
        sweeps a.
    fixed_value : int
        Value the fixed parameter must equal.
    metric : str
        Column of ``df`` to plot.

    Output goes to ``output_area``. ``current_fig`` is set to the new figure,
    or to ``None`` when no row has the requested fixed value.
    """
    global current_fig

    # Any value other than 'a' is treated as "fix b".
    fixed_col, sweep_col = ('a', 'b') if fixed_param == 'a' else ('b', 'a')

    with output_area:
        output_area.clear_output(wait=True)

        filtered_data = df[df[fixed_col] == fixed_value]
        if filtered_data.empty:
            # Checked before creating the figure so that an empty, never-shown
            # figure is not left open in pyplot's figure manager.
            print(f"No data with {fixed_col}={fixed_value}")
            current_fig = None
            return

        # Sort so the line is drawn in order along the swept parameter.
        filtered_data = filtered_data.sort_values(by=sweep_col)

        fig = plt.figure(figsize=(14, 10))
        current_fig = fig

        plt.plot(filtered_data[sweep_col], filtered_data[metric], marker='o', linewidth=3, markersize=10)
        plt.xlabel(f'{sweep_col} value', fontsize=16)
        plt.title(f'Impact of {sweep_col} on {metric} (fixed {fixed_col}={fixed_value})', fontsize=20)
        plt.ylabel(metric, fontsize=16)
        plt.grid(True, alpha=0.3)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)

        plt.show()

        save_btn = widgets.Button(description="Save Figure with Stats")
        save_filename = f"line_{metric}_{fixed_param}{fixed_value}_with_stats.png"
        save_btn.on_click(lambda _btn: save_figure_with_stats(save_filename, filtered_data))
        display(save_btn)

# Scatter view of every (a, b) pair coloured by a chosen metric.
def plot_scatter(metric='TC'):
    """Scatter all (a, b) pairs of ``df`` coloured by ``metric`` and star the best.

    The maximum of ``metric`` is marked with a red star and annotated with
    its value and (a, b). Output goes to ``output_area``. ``current_fig``
    receives the new figure, and a button saves it via
    ``save_figure_with_stats`` using the full ``df``.
    """
    global current_fig

    with output_area:
        output_area.clear_output(wait=True)

        current_fig = plt.figure(figsize=(14, 12))

        points = plt.scatter(df['a'], df['b'], c=df[metric], cmap='viridis',
                             alpha=0.8, s=120, edgecolors='black')

        plt.title(f'Impact of a and b on {metric}', fontsize=20)
        plt.xlabel('a value', fontsize=16)
        plt.ylabel('b value', fontsize=16)
        colour_bar = plt.colorbar(points, label=metric)
        colour_bar.ax.tick_params(labelsize=14)  # colour-bar tick label size
        plt.grid(True, alpha=0.3)
        for set_ticks in (plt.xticks, plt.yticks):
            set_ticks(fontsize=14)

        # Highlight the (a, b) combination where the metric peaks.
        best_row = df[metric].idxmax()
        best_a, best_b, best_value = (df.at[best_row, col] for col in ('a', 'b', metric))

        plt.scatter([best_a], [best_b], c='red', s=250, marker='*', edgecolors='black')
        plt.annotate(f'Max: {best_value:.4f}\n(a={best_a}, b={best_b})',
                     xy=(best_a, best_b), xytext=(best_a + 5, best_b + 5),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
                     fontsize=14)

        plt.show()

        save_button = widgets.Button(description="Save Figure with Stats")
        target_file = f"scatter_{metric}_with_stats.png"
        save_button.on_click(lambda _clicked: save_figure_with_stats(target_file, df))
        display(save_button)

# Correlation between the grid parameters (a, b) and the metrics.
def analyze_correlations():
    """Print the correlation matrix of a, b, TC, TD and TC+TD and plot it as a heatmap.

    Uses ``DataFrame.corr()`` with its default method on the full ``df``.
    Output goes to ``output_area``. ``current_fig`` receives the heatmap
    figure, and a button saves it via ``save_figure_with_stats`` using the
    full ``df``.
    """
    global current_fig

    columns_of_interest = ['a', 'b', 'TC', 'TD', 'TC+TD']

    with output_area:
        output_area.clear_output(wait=True)

        print("Correlation Analysis:")
        matrix = df[columns_of_interest].corr()
        print(matrix)

        current_fig = plt.figure(figsize=(12, 10))
        sns.heatmap(matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, annot_kws={"size": 14})
        plt.title('Variable Correlation Heatmap', fontsize=20)
        for set_ticks in (plt.xticks, plt.yticks):
            set_ticks(fontsize=14)

        plt.show()

        save_button = widgets.Button(description="Save Figure with Stats")
        save_button.on_click(
            lambda _clicked: save_figure_with_stats("correlation_heatmap_with_stats.png", df)
        )
        display(save_button)

# Raw-data view.
def display_data_browser():
    """Replace whatever is in ``output_area`` with the aggregated ``df`` table."""
    area = output_area
    with area:
        area.clear_output(wait=True)
        display(df)

# Control panel for the heatmap view.
def create_heatmap_controls():
    """Build the slider panel that filters the data and calls ``plot_heatmaps``.

    Returns a VBox with five rows of (min, max) sliders, for TC, TD, TC+TD,
    a and b, followed by a "Generate Heatmaps" button. Each click passes
    the current slider values to ``plot_heatmaps``.
    """
    def float_slider(low, high, step, start, label):
        return FloatSlider(min=low, max=high, step=step, value=start, description=label)

    def grid_slider(start, label):
        # The a and b sliders share the same 5..95 range in steps of 5.
        return IntSlider(min=5, max=95, step=5, value=start, description=label)

    tc_lo = float_slider(0.4, 0.7, 0.01, 0.4, 'TC Min:')
    td_lo = float_slider(0.4, 1.0, 0.01, 0.4, 'TD Min:')
    tc_hi = float_slider(0.5, 1.0, 0.01, 0.8, 'TC Max:')
    td_hi = float_slider(0.5, 1.0, 0.01, 0.8, 'TD Max:')
    sum_lo = float_slider(0.8, 1.5, 0.05, 1.0, 'Sum Min:')
    sum_hi = float_slider(1.0, 2.0, 0.05, 1.8, 'Sum Max:')
    a_lo = grid_slider(5, 'a Min:')
    a_hi = grid_slider(95, 'a Max:')
    b_lo = grid_slider(5, 'b Min:')
    b_hi = grid_slider(95, 'b Max:')

    generate = widgets.Button(description="Generate Heatmaps")

    def on_generate(_clicked):
        plot_heatmaps(
            tc_threshold=tc_lo.value,
            td_threshold=td_lo.value,
            tc_upper=tc_hi.value,
            td_upper=td_hi.value,
            sum_threshold=sum_lo.value,
            sum_upper=sum_hi.value,
            a_min=a_lo.value,
            a_max=a_hi.value,
            b_min=b_lo.value,
            b_max=b_hi.value,
        )

    generate.on_click(on_generate)

    slider_rows = [
        (tc_lo, tc_hi),
        (td_lo, td_hi),
        (sum_lo, sum_hi),
        (a_lo, a_hi),
        (b_lo, b_hi),
    ]
    return VBox([HBox(list(pair)) for pair in slider_rows] + [generate])

# Control panel for the line-plot view.
def create_line_controls():
    """Build the panel that chooses the fixed parameter, its value and the metric.

    The value slider spans from the smallest to the largest value found in
    either the ``a`` or the ``b`` column of ``df`` (step 5, initial value 10).
    Clicking "Generate Line Plot" calls ``plot_line`` with the selections.
    """
    slider_low = min(min(df[col]) for col in ('a', 'b'))
    slider_high = max(max(df[col]) for col in ('a', 'b'))

    param_picker = widgets.Dropdown(
        options=['a', 'b'], value='a',
        description='Fix Param:', layout=Layout(width='200px'),
    )
    value_slider = widgets.IntSlider(
        min=slider_low, max=slider_high, step=5, value=10,
        description='Fixed Value:', layout=Layout(width='400px'),
    )
    metric_picker = widgets.Dropdown(
        options=['TC', 'TD', 'TC+TD'], value='TC',
        description='Metric:', layout=Layout(width='200px'),
    )

    run_button = widgets.Button(description="Generate Line Plot")
    run_button.on_click(
        lambda _clicked: plot_line(param_picker.value, value_slider.value, metric_picker.value)
    )

    top_row = HBox([param_picker, value_slider])
    bottom_row = HBox([metric_picker, run_button])
    return VBox([top_row, bottom_row])

# Control panel for the scatter view.
def create_scatter_controls():
    """Return a one-row panel: a metric dropdown plus a button that calls ``plot_scatter``."""
    metric_picker = widgets.Dropdown(
        options=['TC', 'TD', 'TC+TD'], value='TC',
        description='Metric:', layout=Layout(width='200px'),
    )
    run_button = widgets.Button(description="Generate Scatter Plot")
    run_button.on_click(lambda _clicked: plot_scatter(metric_picker.value))
    return HBox([metric_picker, run_button])

# Top-level dashboard: section toolbar, controls area and shared output area.
def create_toolbar():
    """Display the section toolbar, the controls area and ``output_area``.

    Each toolbar button swaps the controls area for that section's panel:
    heatmap, line, or scatter controls, or a single action button for the
    correlation analysis and the raw-data view. The heatmap controls are
    shown initially.
    """
    controls_area = Output()

    def render_panel(build_panel):
        # Replace whatever is in the controls area with a freshly built panel.
        with controls_area:
            controls_area.clear_output(wait=True)
            display(build_panel())

    def single_action_panel(label, action):
        # A panel made of one button that runs ``action`` when clicked.
        button = widgets.Button(description=label)
        button.on_click(lambda _clicked: action())
        return button

    # (toolbar label, panel factory). The lambdas defer the name lookups to
    # click time.
    sections = [
        ("Heatmaps", lambda: create_heatmap_controls()),
        ("Line Plots", lambda: create_line_controls()),
        ("Scatter Plots", lambda: create_scatter_controls()),
        ("Correlation", lambda: single_action_panel(
            "Run Correlation Analysis", lambda: analyze_correlations())),
        ("View Data", lambda: single_action_panel(
            "View Raw Data", lambda: display_data_browser())),
    ]

    section_buttons = []
    for label, build_panel in sections:
        button = widgets.Button(description=label, layout=Layout(width='120px'))
        # Bind build_panel as a default argument to avoid late-binding closures.
        button.on_click(lambda _clicked, build=build_panel: render_panel(build))
        section_buttons.append(button)

    display(HBox(section_buttons))
    display(controls_area)

    # Start on the heatmap section.
    render_panel(sections[0][1])

    display(output_area)

# Build and display the interactive dashboard (toolbar, controls area, output area).
create_toolbar()

HBox(children=(Button(description='Heatmaps', layout=Layout(width='120px'), style=ButtonStyle()), Button(descr…

Output()

Output()

Figure and statistics saved to: heatmaps_with_stats_tc0.40-0.80_td0.40-0.80.png
