In [None]:
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib.font_manager")

import importlib
import json
import os
from collections import defaultdict
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

from fusion_bench.utils.json import load_from_json
import matplotlib.ticker as ticker

from plot_utils import TASK_TO_LABEL_MAPPING
from plot_utils import v2_colors as COLORS
from plot_utils import v2_colors as COLORS_light
from plot_utils import extra_darker_v2_colors as COLORS_dark


# Figure styling: Times New Roman for text, Computer Modern ("cm") for mathtext.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["mathtext.fontset"] = "cm"

# Embed fonts as TrueType (Type 42) instead of Type 3 in saved PDF/PS files.
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Absolute path of the parent directory of the installed `fusion_bench`
# package (presumably the repository root — confirm for non-editable installs).
PROJECT_ROOT = Path(
    os.path.abspath(
        os.path.join(importlib.import_module("fusion_bench").__path__[0], "..")
    )
)

In [None]:
import json
from pathlib import Path
import pandas as pd

def load_json(path):
    """
    Read and parse a JSON file.

    Args:
        path: file path (``str`` or ``os.PathLike``).

    Returns:
        The decoded JSON value (the report files read in this notebook are
        JSON objects, i.e. ``dict``).
    """
    # JSON text is UTF-8; decode explicitly rather than relying on the
    # platform's locale encoding, which differs between machines.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def compute_bwt_and_acc(version_dir: str) -> pd.DataFrame:
    """
    Summarize one run directory (e.g. ``version_0``) by average accuracy and BWT.

    - ``acc``: ``report.json["average"]["accuracy"]``.
    - ``bwt``: backward transfer. The ``report_<k>.json`` files are ordered by
      their integer ``k`` and the last one is treated as the final merged
      model. For the i-th earlier report, the i-th of its task keys (keys other
      than ``model_info`` / ``average``) is taken as the task added at that
      step, and one BWT term is
      ``final[task]["accuracy"] - report_i[task]["accuracy"]``.
      The terms are averaged; ``bwt`` is 0 when there are fewer than two
      ``report_*.json`` files.

    A diagnostic line is printed when the i-th report does not contain exactly
    i + 1 tasks, which is the layout this computation assumes.

    Args:
        version_dir: directory holding ``report.json`` and ``report_<k>.json``.

    Returns:
        One-row DataFrame with columns ``['version', 'acc', 'bwt']``; both
        metrics are multiplied by 100 (percent if the stored values are
        fractions).
    """
    version_dir = Path(version_dir)
    reports = sorted(
        version_dir.glob("report_*.json"),
        key=lambda p: int(p.stem.split("_")[1]),
    )
    meta_keys = {"model_info", "average"}

    bwt_sum = 0
    bwt_n = 0

    if len(reports) > 1:
        # The final model's report is the same for every BWT term: read it
        # once instead of once per intermediate report.
        data_T = load_json(reports[-1])

        for i, report_path in enumerate(reports[:-1]):  # exclude final
            data_i = load_json(report_path)
            task_keys = [k for k in data_i.keys() if k not in meta_keys]

            # Check the layout before indexing, so an unexpected report is
            # reported even when task_keys[i] below raises IndexError.
            if not len(task_keys) == i + 1:
                print(f'task num: {len(task_keys)}, report_path: {report_path}')
            last_task = task_keys[i]

            acc_i = data_i[last_task]["accuracy"]
            acc_T = data_T[last_task]["accuracy"]
            bwt_sum += acc_T - acc_i
            bwt_n += 1

    bwt = bwt_sum / bwt_n if bwt_n > 0 else 0

    # Average accuracy of the final model, as stored in report.json.
    acc_data = load_json(version_dir / "report.json")
    acc = acc_data["average"]["accuracy"]

    return pd.DataFrame([{
        "version": version_dir.name,
        "acc": acc * 100,
        "bwt": bwt * 100
    }])


def compute_task_acc(version_dir: str) -> pd.DataFrame:
    """
    Per-task accuracy of the final merged model of one run directory.

    The final report is the ``report_<k>.json`` with the largest integer ``k``
    in ``version_dir``. Every top-level key except ``model_info`` and
    ``average`` is treated as a task, and its ``["accuracy"]`` entry is read.

    Args:
        version_dir: directory holding ``report_<k>.json`` files
            (e.g. ``.../version_0``).

    Returns:
        One-row DataFrame with one column per task. Values are returned as
        stored in the JSON (not multiplied by 100, unlike
        ``compute_bwt_and_acc``).

    Raises:
        FileNotFoundError: if ``version_dir`` holds no ``report_*.json``.
    """
    version_dir = Path(version_dir)
    reports = list(version_dir.glob("report_*.json"))
    if not reports:
        raise FileNotFoundError(f"No report_*.json found in {version_dir}")

    # Only the latest report is needed: max() instead of a full sort.
    final_report = max(reports, key=lambda p: int(p.stem.split("_")[1]))
    report = load_json(final_report)

    meta_keys = {"model_info", "average"}
    return pd.DataFrame([{
        task: metrics["accuracy"]
        for task, metrics in report.items()
        if task not in meta_keys
    }])



In [None]:
# ViT-B/32, `vit-b-32-TA8` setting: for every method, mean and sample std of
# the final average accuracy and BWT over all of its `version_*` runs.
base_dir = Path("/data1/zihuanqiu/mingle/outputs")
exp_names = [
    "weight_average",
    "continual_task_arithmetic",
    "continual_ties_merging",
    "magmax",
    "consensus_ta",
    "opcm",
    "c_adamerging",
    "c_wemoe",
    "mingle",
    "mingle_star"
]

summary_results = []

for exp in exp_names:
    version_root = base_dir / exp / "vit-b-32-TA8"
    version_dirs = sorted(version_root.glob("version_*"))
    if not version_dirs:
        print(f"⚠️ No version_* found in: {version_root}")
        continue

    # One row per run (version_*) of this method.
    df = pd.concat([compute_bwt_and_acc(v) for v in version_dirs], ignore_index=True)
    display(df)

    # Aggregate across runs; .std() is the sample std (ddof=1), NaN for a single run.
    summary_results.append({
        "exp": exp,
        "acc": df["acc"].mean(),
        "acc_std": df["acc"].std(),
        "bwt": df["bwt"].mean(),
        "bwt_std": df["bwt"].std(),
    })

# One row per method.
final_df = pd.DataFrame(summary_results)
display(final_df)

In [None]:

def compute_bwt_and_acc_nlp(version_dir: str) -> pd.DataFrame:
    """
    NLP counterpart of ``compute_bwt_and_acc`` for one run directory.

    The ``report_<k>.json`` files are ordered by their integer ``k`` and the
    last one is treated as the final merged model. A task's score is the
    *last* value of its metrics dict, not ``["accuracy"]``. This relies on the
    JSON key order; presumably it is the headline metric — TODO confirm per
    dataset.

    - ``bwt``: for the i-th earlier report, the i-th of its task keys (keys
      other than ``model_info`` / ``average``) is taken as the task added at
      that step, and one term is ``score(final[task]) - score(report_i[task])``.
      The terms are averaged; ``bwt`` is 0 when there is only one report.
    - ``acc``: mean score over all tasks of the final report (``report.json``
      is not read).

    A diagnostic line is printed when the i-th report does not contain exactly
    i + 1 tasks.

    Args:
        version_dir: directory holding ``report_<k>.json`` files.

    Returns:
        One-row DataFrame with columns ``['version', 'acc', 'bwt']``; both
        metrics are multiplied by 100.

    Raises:
        FileNotFoundError: if ``version_dir`` holds no ``report_*.json``.
    """
    version_dir = Path(version_dir)
    reports = sorted(version_dir.glob("report_*.json"), key=lambda x: int(x.stem.split("_")[1]))
    if not reports:
        raise FileNotFoundError(f"No report_*.json found in {version_dir}")

    meta_keys = {"model_info", "average"}

    def _score(task_metrics):
        # Last metric recorded for a task (dicts preserve JSON key order).
        return list(task_metrics.values())[-1]

    # Final merged model. It is loaded before the loop so that it is defined
    # even with a single report (when the loop body never runs), and it is read
    # only once.
    data_T = load_json(reports[-1])

    bwt_sum = 0
    bwt_n = 0

    for i, report_path in enumerate(reports[:-1]):  # exclude final
        data_i = load_json(report_path)
        task_keys = [k for k in data_i.keys() if k not in meta_keys]

        # Check the layout before indexing, so an unexpected report is
        # reported even when task_keys[i] below raises IndexError.
        if not len(task_keys) == i + 1:
            print(f'task num: {len(task_keys)}, report_path: {report_path}')
        last_task = task_keys[i]

        acc_i = _score(data_i[last_task])
        acc_T = _score(data_T[last_task])
        bwt_sum += acc_T - acc_i
        bwt_n += 1

    bwt = bwt_sum / bwt_n if bwt_n > 0 else 0

    # Average score of the final model over all of its tasks.
    final_tasks = [k for k in data_T.keys() if k not in meta_keys]
    acc = np.mean([_score(data_T[task]) for task in final_tasks])

    return pd.DataFrame([{
        "version": version_dir.name,
        "acc": acc * 100,
        "bwt": bwt * 100
    }])

# T5-base NLP setting: for every method, mean and sample std of the final
# average score and BWT over all of its `version_*` runs.
base_dir = Path("/data1/zihuanqiu/mingle/outputs")
exp_names = [
    "c_adamerging_nlp",
    "c_wemoe_nlp",
    "ties_nlp",
    "ta_nlp",
    "opcm_nlp",
    "mingle_nlp"
]

summary_results = []

for exp in exp_names:
    version_root = base_dir / exp / "t5-base"
    version_dirs = sorted(version_root.glob("version_*"))
    if not version_dirs:
        print(f"⚠️ No version_* found in: {version_root}")
        continue

    # One row per run (version_*) of this method.
    df = pd.concat([compute_bwt_and_acc_nlp(v) for v in version_dirs], ignore_index=True)
    display(df)

    # Aggregate across runs; .std() is the sample std (ddof=1), NaN for a single run.
    summary_results.append({
        "exp": exp,
        "acc": df["acc"].mean(),
        "acc_std": df["acc"].std(),
        "bwt": df["bwt"].mean(),
        "bwt_std": df["bwt"].std(),
    })

# One row per method: mean and std of both acc and bwt across runs.
final_df = pd.DataFrame(summary_results)
display(final_df)

In [None]:
# `vit-l-14-TALL20` setting: per-task accuracy of each method's final model,
# averaged over all of its `version_*` runs. Values are as stored in the
# reports (not multiplied by 100).
base_dir = Path("/data1/zihuanqiu/mingle/outputs")
exp_names = [
    "weight_average",
    "continual_task_arithmetic",
    "continual_ties_merging",
    "magmax",
    "consensus_ta",
    "opcm",
    "mingle",
    "mingle_star"
]

task_acc_by_exp = {}

for exp in exp_names:
    version_root = base_dir / exp / "vit-l-14-TALL20"
    version_dirs = sorted(version_root.glob("version_*"))
    if not version_dirs:
        print(f"⚠️ No version_* found in: {version_root}")
        continue

    # One row per run, one column per task; average over runs.
    df = pd.concat([compute_task_acc(v) for v in version_dirs], ignore_index=True)
    task_acc_by_exp[exp] = df.mean(axis=0)

# A single comparison table: rows are tasks, columns are methods.
display(pd.DataFrame(task_acc_by_exp))

In [None]:
# Load and compute mean gate outputs for each task and channel.
# `vit-b-16-TA8` setting: one panel per gate. Each panel shows that gate's
# mean activation (± 1 std) on every task's inputs, with one line per run in
# `path`. path[0] is the run labelled "w/o Null-Space"; the others are
# gamma settings.
base_dir = Path("/data1/zihuanqiu/mingle/outputs")

path = [
    "gate_state_wo_cons_0",
    "hyper_analyz_gamma_0",
    "hyper_analyz_gamma_1",
    "hyper_analyz_gamma_3",
]

tasks = ['EuroSAT', 'SVHN', 'MNIST', 'DTD', 'RESISC45', 'GTSRB', 'SUN397', 'Cars']
# Dataset name used in each task's gate-stats filename.
task_files = {
    'EuroSAT': 'eurosat',
    'SVHN': 'svhn',
    'MNIST': 'mnist',
    'DTD': 'dtd',
    'RESISC45': 'resisc45',
    'GTSRB': 'gtsrb',
    'SUN397': 'sun397',
    'Cars': 'stanford-cars',
}
T = len(tasks)

# 2x4 grid: one panel per gate (8 gates).
fig, axes = plt.subplots(2, 4, figsize=(10, 4.5))
axes = axes.flatten()

# Fixed y-range shared by all panels.
y_min = -0.3
y_max = 1.3
x = np.arange(T)

# Legend handles/labels, one per run, collected from the first panel.
all_handles = []
all_labels = []

for gamma_idx, gamma_path in enumerate(path):
    run_dir = base_dir / gamma_path / 'vit-b-16-TA8' / 'version_0'

    # means[i, j] / stds[i, j]: statistics of gate j on task i's inputs.
    # Each npz entry is stacked row-wise; presumably one row per batch/sample.
    means = np.zeros((T, T))
    stds = np.zeros((T, T))
    for i, task in enumerate(tasks):
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(run_dir / f'7_gate_stats_{task_files[task]}.npz') as data:
            arr = np.vstack([data[k] for k in sorted(data.files)])
        means[i, :] = arr.mean(axis=0)
        stds[i, :] = arr.std(axis=0)

    current_color = COLORS[gamma_idx]
    gamma_value = gamma_path.split("_")[-1]
    label = rf'$\gamma=$ {gamma_value}'

    for j in range(T):
        ax = axes[j]
        gate_values = means[:, j]
        gate_stds = stds[:, j]

        # Low-alpha ±1 std band so overlapping runs stay readable.
        ax.fill_between(x, gate_values - gate_stds, gate_values + gate_stds,
                        color=current_color, alpha=0.1)
        line, = ax.plot(x, gate_values, marker='o', markersize=2,
                        color=current_color, linewidth=1., zorder=3,
                        label=label)

        if j == 0:
            all_handles.append(line)
            all_labels.append(label)

# Per-panel titles, ticks and formatting.
for j in range(T):
    ax = axes[j]
    ax.set_title(f'Gate {j+1}', fontsize=10)
    ax.set_xticks(range(T))
    ax.set_xticklabels(tasks, rotation=45, ha='right', fontsize=9)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    ax.set_ylabel('Gate Activation', fontsize=10)
    ax.spines["top"].set_visible(True)
    ax.spines["right"].set_visible(True)
    # Shade x = j, where the input task index matches the gate index.
    ax.axvline(x=j, color='gray', alpha=0.2, linestyle='-', linewidth=5)

# path[0] is the run without the null-space constraint; relabel its entry.
all_labels[0] = 'w/o Null-Space'
fig.legend(all_handles, all_labels, loc='upper center', bbox_to_anchor=(0.5, 0.05),
           ncol=len(path), frameon=False, fontsize=10, columnspacing=4.0)

fig.tight_layout()
fig.subplots_adjust(top=0.85, wspace=0.3, hspace=0.7)
Path("images").mkdir(exist_ok=True)
fig.savefig("images/gate_comparison_T8.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Load and compute mean gate outputs for each task and channel.
# `vit-b-16-TALL14` setting: one panel per gate. Each panel shows that gate's
# mean activation (± 1 std) on every task's inputs, with one line per run in
# `path`. path[0] is the run labelled "w/o Null-Space"; the others are
# gamma settings.
base_dir = Path("/data1/zihuanqiu/mingle/outputs")

path = [
    "gate_state_wo_cons_0",
    "hyper_analyz_gamma_0",
    "hyper_analyz_gamma_1",
    "hyper_analyz_gamma_3",
]

tasks = ['Flowers102', 'STL10', 'DTD', 'MNIST', 'CIFAR100', 'OxfordIIITPet', 'GTSRB', 'RESISC45',
         'PCAM', 'EuroSAT', 'SVHN', 'SUN397', 'Cars', 'FER2013']
# Dataset name used in each task's gate-stats filename.
task_files = {
    'Flowers102': 'oxford_flowers102',
    'STL10': 'stl10',
    'DTD': 'dtd',
    'MNIST': 'mnist',
    'CIFAR100': 'cifar100',
    'OxfordIIITPet': 'oxford-iiit-pet',
    'GTSRB': 'gtsrb',
    'RESISC45': 'resisc45',
    'PCAM': 'pcam',
    'EuroSAT': 'eurosat',
    'SVHN': 'svhn',
    'SUN397': 'sun397',
    'Cars': 'stanford-cars',
    'FER2013': 'fer2013',
}
T = len(tasks)

# 5x3 grid for the 14 gates; the spare panel is removed below.
fig, axes = plt.subplots(5, 3, figsize=(12, 10))
axes = axes.flatten()

# Fixed y-range shared by all panels.
y_min = -0.3
y_max = 1.3
x = np.arange(T)

# Legend handles/labels, one per run, collected from the first panel.
all_handles = []
all_labels = []

for gamma_idx, gamma_path in enumerate(path):
    run_dir = base_dir / gamma_path / 'vit-b-16-TALL14' / 'version_0'

    # means[i, j] / stds[i, j]: statistics of gate j on task i's inputs.
    means = np.zeros((T, T))
    stds = np.zeros((T, T))
    for i, task in enumerate(tasks):
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(run_dir / f'13_gate_stats_{task_files[task]}.npz') as data:
            arr = np.vstack([data[k] for k in sorted(data.files)])
        means[i, :] = arr.mean(axis=0)
        stds[i, :] = arr.std(axis=0)

    current_color = COLORS[gamma_idx]
    gamma_value = gamma_path.split("_")[-1]
    label = rf'$\gamma=$ {gamma_value}'

    for j in range(T):
        ax = axes[j]
        gate_values = means[:, j]
        gate_stds = stds[:, j]

        # Low-alpha ±1 std band so overlapping runs stay readable.
        ax.fill_between(x, gate_values - gate_stds, gate_values + gate_stds,
                        color=current_color, alpha=0.1)
        line, = ax.plot(x, gate_values, marker='o', markersize=2,
                        color=current_color, linewidth=1., zorder=3,
                        label=label)

        if j == 0:
            all_handles.append(line)
            all_labels.append(label)

# Per-panel titles, ticks and formatting.
for j in range(T):
    ax = axes[j]
    ax.set_title(f'Gate {j+1}', fontsize=10)
    ax.set_xticks(range(T))
    ax.set_xticklabels(tasks, rotation=45, ha='right', fontsize=9)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    ax.set_ylabel('Gate Activation', fontsize=10)
    ax.spines["top"].set_visible(True)
    ax.spines["right"].set_visible(True)
    # Shade x = j, where the input task index matches the gate index.
    ax.axvline(x=j, color='gray', alpha=0.2, linestyle='-', linewidth=5)

# Drop the unused panel(s) of the grid.
for ax in axes[T:]:
    fig.delaxes(ax)

# path[0] is the run without the null-space constraint; relabel its entry.
all_labels[0] = 'w/o Null-Space'
fig.legend(all_handles, all_labels, loc='upper center', bbox_to_anchor=(0.5, 0.0),
           ncol=len(path), frameon=False, fontsize=13, columnspacing=4.0)

fig.tight_layout()
fig.subplots_adjust(top=0.85, wspace=0.2, hspace=0.8)
Path("images").mkdir(exist_ok=True)
fig.savefig("images/gate_comparison_T14.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Load and compute mean gate outputs for each task and channel.
# `vit-b-16-TALL20` setting: one panel per gate. Each panel shows that gate's
# mean activation (± 1 std) on every task's inputs, with one line per run in
# `path`. path[0] is the run labelled "w/o Null-Space"; the others are
# gamma settings.
base_dir = Path("/data1/zihuanqiu/mingle/outputs")

path = [
    "gate_state_wo_cons_0",
    "hyper_analyz_gamma_0",
    "hyper_analyz_gamma_1",
    "hyper_analyz_gamma_3",
]

tasks = ['RenderedSST2', 'GTSRB', "CIFAR10", "SVHN", "PCAM", "CIFAR100", "Food101", "KMNIST",
         "MNIST", "STL10", "EMNIST", "FER2013", 'Cars', 'OxfordIIITPet', 'RESISC45', "FashionMNIST",
         "DTD", "Flowers102", "SUN397", 'EuroSAT']
T = len(tasks)

# 7x3 grid for the 20 gates; the spare panel is removed below.
fig, axes = plt.subplots(7, 3, figsize=(15, 20))
axes = axes.flatten()


def task_to_filename(task: str):
    """
    Map a task label to the dataset name used in its gate-stats filename.

    Labels missing from the table fall back to the label itself; the result
    is always lower-cased.
    """
    special = {
        'RenderedSST2': 'rendered-sst2',
        'OxfordIIITPet': 'oxford-iiit-pet',
        'FashionMNIST': 'fashion_mnist',
        'Flowers102': 'oxford_flowers102',
        'EuroSAT': 'eurosat',
        'CIFAR10': 'cifar10',
        'CIFAR100': 'cifar100',
        'SVHN': 'svhn',
        'PCAM': 'pcam',
        'GTSRB': 'gtsrb',
        'STL10': 'stl10',
        'EMNIST': 'emnist_letters',
        'FER2013': 'fer2013',
        'Cars': 'stanford-cars',
        'KMNIST': 'kmnist',
        'MNIST': 'mnist',
        'DTD': 'dtd',
        'SUN397': 'sun397',
        'RESISC45': 'resisc45',
        'Food101': 'food101'
    }
    return special.get(task, task).lower()


# Fixed y-range shared by all panels.
y_min = -0.3
y_max = 1.3
x = np.arange(T)

# Legend handles/labels, one per run, collected from the first panel.
all_handles = []
all_labels = []

for gamma_idx, gamma_path in enumerate(path):
    run_dir = base_dir / gamma_path / 'vit-b-16-TALL20' / 'version_0'

    # means[i, j] / stds[i, j]: statistics of gate j on task i's inputs.
    means = np.zeros((T, T))
    stds = np.zeros((T, T))
    for i, task in enumerate(tasks):
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(run_dir / f"19_gate_stats_{task_to_filename(task)}.npz") as data:
            arr = np.vstack([data[k] for k in sorted(data.files)])
        means[i, :] = arr.mean(axis=0)
        stds[i, :] = arr.std(axis=0)

    current_color = COLORS[gamma_idx]
    gamma_value = gamma_path.split("_")[-1]
    label = rf'$\gamma=$ {gamma_value}'

    for j in range(T):
        ax = axes[j]
        gate_values = means[:, j]
        gate_stds = stds[:, j]

        # Low-alpha ±1 std band so overlapping runs stay readable.
        ax.fill_between(x, gate_values - gate_stds, gate_values + gate_stds,
                        color=current_color, alpha=0.1)
        line, = ax.plot(x, gate_values, marker='o', markersize=2,
                        color=current_color, linewidth=1., zorder=3,
                        label=label)

        if j == 0:
            all_handles.append(line)
            all_labels.append(label)

# Per-panel titles, ticks and formatting.
for j in range(T):
    ax = axes[j]
    ax.set_title(f'Gate {j+1}', fontsize=10)
    ax.set_xticks(range(T))
    ax.set_xticklabels(tasks, rotation=45, ha='right', fontsize=9)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    ax.set_ylabel('Gate Activation', fontsize=10)
    ax.spines["top"].set_visible(True)
    ax.spines["right"].set_visible(True)
    # Shade x = j, where the input task index matches the gate index.
    ax.axvline(x=j, color='gray', alpha=0.2, linestyle='-', linewidth=5)

# Drop the unused panel(s) of the grid.
for ax in axes[T:]:
    fig.delaxes(ax)

# path[0] is the run without the null-space constraint; relabel its entry.
all_labels[0] = 'w/o Null-Space'
fig.legend(all_handles, all_labels, loc='upper center', bbox_to_anchor=(0.5, 0.0),
           ncol=len(path), frameon=False, fontsize=13, columnspacing=4.0)

fig.tight_layout()
fig.subplots_adjust(top=0.85, wspace=0.2, hspace=0.8)
Path("images").mkdir(exist_ok=True)
fig.savefig("images/gate_comparison_T20.pdf", bbox_inches="tight")
plt.show()