In [1]:
import os
import numpy as np
import pandas as pd

# Collect the per-position explanation scores stored as .npy files under
# <path_result>/<tool>/ and write one CSV per (tool, peptide length).
tools=['mhcflurry_ba','mhcflurry_ps','netmhcpan_ba','netmhcpan_el','bigmhc','capsnetmhc_an','transphla','stmhcpan']
pep_len=[8,9,10,11]
path_result='/data1/wuguojia/data/mhc_benchmark/attentionbase/result/'
# Process every tool.
for tool in tools:
    # One empty row list per supported peptide length.
    data_lists = {length: [] for length in pep_len}
    tool_path = os.path.join(path_result, tool)
    # Sorted so that the row order of the generated CSV is reproducible.
    npy_files = sorted(f for f in os.listdir(tool_path) if f.endswith('.npy'))
    # Process every explanation file.
    for npy_file in npy_files:
        # File names are read as <mode>_<peptide>_<allele>...; only the first
        # three '_'-separated fields are used.
        # NOTE(review): if a name has exactly three fields, the allele field
        # would still carry the '.npy' suffix -- confirm the naming scheme.
        parts = npy_file.split('_')
        if len(parts) < 3:
            print(f'skip {npy_file}: file name does not have mode, peptide and allele fields')
            continue
        mode, peptide, allele = parts[:3]
        length = len(peptide)
        # Only peptides whose length is listed in pep_len are kept.
        if length not in pep_len:
            continue
        if mode not in ('LIME', 'SHAP'):
            # Without this guard the slice of the previously processed file
            # would be reused for this row (or a NameError raised on the first file).
            print(f'skip {npy_file}: unknown mode {mode!r}')
            continue
        # Load the stored score vector.
        file_path = os.path.join(tool_path, npy_file)
        file_data = np.load(file_path)
        # Leading metadata columns.
        row_data = [mode, allele, length, peptide]
        # Select the per-position values according to the mode.
        if mode == 'LIME':
            pos_values = file_data[1:length + 1]  # 2nd through (length+1)-th value
        else:  # SHAP
            pos_values = file_data[:length]  # first `length` values
        # Pad with None when the file holds fewer values than positions.
        pos_values = list(pos_values) + [None] * (length - len(pos_values))
        row_data.extend(pos_values)
        # Tool column.
        row_data.append(tool)
        # bind_result_tool and bind_result_base are left empty at this stage.
        row_data.extend([None, None])
        data_lists[length].append(row_data)
    # Write one CSV per peptide length that has at least one row.
    for length in pep_len:
        if data_lists[length]:
            columns = ['mode', 'allele', 'length', 'peptide'] + [f'pos_{i+1}' for i in range(length)] + ['tool', 'bind_result_tool', 'bind_result_base']
            data_df = pd.DataFrame(data_lists[length], columns=columns)
            data_df.to_csv(os.path.join(path_result, f'{tool}_length_{length}.csv'), index=False)

In [2]:
# Fill the bind_result_base / bind_result_tool columns of every result CSV
# from the concatenated test tables.
path_test_use='/data1/wuguojia/data/mhc_benchmark/attentionbase/testdata_use/'
test_files = [f for f in os.listdir(path_test_use) if f.endswith('.csv')]
test_list = []
for test_file in test_files:
    file_path = os.path.join(path_test_use, test_file)
    df = pd.read_csv(file_path)
    test_list.append(df)
test_df = pd.concat(test_list, ignore_index=True)  # all test tables stacked

# Position of the FIRST test row for each (hlatype, peptide) pair. Building the
# lookup once replaces a full scan of test_df for every result row while
# keeping the same "first matching row wins" behaviour.
first_match_pos = {}
for pos, key in enumerate(zip(test_df['hlatype'], test_df['antigen_peptide'])):
    first_match_pos.setdefault(key, pos)

result_files=[f for f in os.listdir(path_result) if f.endswith('.csv')]
for result_file in result_files:
    # The part of the file name before '_length' is the tool name, which is
    # also the name of the tool's column in the test tables.
    tool = result_file.split('_length')[0]
    file_path=os.path.join(path_result, result_file)
    if tool not in test_df.columns:
        # A CSV that does not belong to a known tool would otherwise abort the
        # whole cell with a KeyError on the first matching row.
        print(f'skip {result_file}: test data has no column {tool!r}')
        continue
    df = pd.read_csv(file_path)
    base_column = test_df['bind_result']
    tool_column = test_df[tool]
    # Update the rows whose (allele, peptide) pair occurs in the test data;
    # rows without a match are left untouched.
    for i, key in zip(df.index, zip(df['allele'], df['peptide'])):
        pos = first_match_pos.get(key)
        if pos is None:
            continue
        df.at[i, 'bind_result_base'] = base_column.iat[pos]
        df.at[i, 'bind_result_tool'] = tool_column.iat[pos]
    # Sort by mode, allele and bind_result_base, then overwrite the file in place.
    df = df.sort_values(by=['mode', 'allele', 'bind_result_base'])
    df.to_csv(file_path, index=False)

In [5]:
# Report, for every result CSV, the (mode, allele) groups that contain fewer
# than 100 rows, i.e. groups whose explanations were not all produced.
result_files = [name for name in os.listdir(path_result) if name.endswith('.csv')]
for csv_name in result_files:
    print(csv_name)
    table = pd.read_csv(os.path.join(path_result, csv_name))
    # Row count per (mode, allele) group, as a flat frame with a 'count' column.
    group_sizes = table.groupby(['mode', 'allele']).size().reset_index(name='count')
    # Keep only the under-filled groups and show them.
    under_filled = group_sizes.loc[group_sizes['count'] < 100]
    print(under_filled)

stmhcpan_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
mhcflurry_ps_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
bigmhc_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
transphla_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
netmhcpan_el_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
netmhcpan_ba_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
capsnetmhc_an_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []
mhcflurry_ba_length_9.csv
Empty DataFrame
Columns: [mode, allele, count]
Index: []


In [7]:
#prepare prediction
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import pickle, importlib
import argparse
import subprocess
import sys
import lime
import lime.lime_tabular
import shap
%matplotlib inline

In [8]:
# this part aims to add lack data in attentionbase
import os
import pandas as pd
# Build the list of explanations that are still missing: every
# (hlatype, peptide) pair of the test set that has no row in the result CSVs
# for a given (mode, tool) combination.

# Test data: all test tables stacked.
path_test_use='/data1/wuguojia/data/mhc_benchmark/attentionbase/testdata_use/'
test_files = [f for f in os.listdir(path_test_use) if f.endswith('.csv')]
test_list = []
for test_file in test_files:
    file_path = os.path.join(path_test_use, test_file)
    df = pd.read_csv(file_path)
    test_list.append(df)
test_df = pd.concat(test_list, ignore_index=True)
df1=test_df[['hlatype','antigen_peptide','antigen_peptide_length']]

# Result data: all result CSVs stacked.
path_result='/data1/wuguojia/data/mhc_benchmark/attentionbase/result/'
result_files=[f for f in os.listdir(path_result) if f.endswith('.csv')]
df_list = [pd.read_csv(os.path.join(path_result, file)) for file in result_files]
merged_df = pd.concat(df_list, ignore_index=True)
df2=merged_df[['mode', 'allele', 'peptide', 'tool']]

path_train='/data1/wuguojia/data/mhc_benchmark/attentionbase/traindata_raw/'
# Extension of the training file used for each explanation mode.
train_suffix = {'LIME': '.csv', 'SHAP': '.pkl'}
result_columns = ['mode','tool','test_path','train_path','position']
# Each per-allele test table is read once, not once per missing peptide.
test_table_cache = {}

# Find the missing (mode, tool, hlatype, peptide) combinations.
all_results = []
for mode in df2['mode'].unique():
    for tool in df2['tool'].unique():
        df3 = df2[(df2['mode'] == mode) & (df2['tool'] == tool)]
        df3_renamed = df3.rename(columns={'allele': 'hlatype', 'peptide': 'antigen_peptide'})
        # Left merge with indicator: 'left_only' rows exist in the test set
        # but not in the results of this (mode, tool).
        merge_check = pd.merge(df1, df3_renamed, on=['hlatype', 'antigen_peptide'], how='left', indicator=True)
        # .copy() so the column assignments below act on an independent frame
        # and do not raise SettingWithCopyWarning.
        missing = merge_check.loc[merge_check['_merge'] == 'left_only', ['hlatype', 'antigen_peptide']].copy()
        missing['mode']=mode
        missing['tool']=tool
        missing['length'] = missing['antigen_peptide'].apply(len)
        # Test/train files are named <hlatype>_<peptide length>.<ext>.
        missing['hlatype_length'] = missing['hlatype'] + '_' + missing['length'].astype(str)

        positions, test_paths, train_paths = [], [], []
        for hlatype_length, antigen_peptide in zip(missing['hlatype_length'], missing['antigen_peptide']):
            test_path = os.path.join(path_test_use, hlatype_length+'.csv')
            if test_path not in test_table_cache:
                test_table_cache[test_path] = pd.read_csv(test_path)
            test_table = test_table_cache[test_path]
            # Row index of the peptide's first occurrence inside its test file.
            hits = test_table.index[test_table['antigen_peptide'] == antigen_peptide]
            if len(hits) == 0:
                raise IndexError(f'peptide {antigen_peptide} not found in {test_path}')
            positions.append(int(hits[0]))
            test_paths.append(test_path)
            # Training file path; None when the mode is neither LIME nor SHAP.
            suffix = train_suffix.get(mode)
            train_paths.append(os.path.join(path_train, hlatype_length+suffix) if suffix else None)
        missing['position'] = positions
        missing['test_path'] = test_paths
        missing['train_path'] = train_paths
        all_results.append(missing[result_columns])

# pd.concat raises on an empty list, so fall back to an empty frame when the
# result CSVs contain no (mode, tool) combination at all.
if all_results:
    result = pd.concat(all_results, ignore_index=True)
else:
    result = pd.DataFrame(columns=result_columns)
pd.set_option('display.max_colwidth', None)  # do not truncate the long path columns
print(result)

Empty DataFrame
Columns: [mode, tool, test_path, train_path, position]
Index: []


In [3]:
# Run MHCXAI.py once for every missing explanation listed in `result`.
# NOTE(review): --predictor is always "mhcflurry" and --mode only distinguishes
# mhcflurry_ba from everything else, so rows for other tools would be run with
# the mhcflurry predictor -- confirm this is intended.
for index, row in result.iterrows():
    # Command-line arguments for this row.
    command = [
        "python", "/home/wuguojia/biocode/mhc_benchmark/MHCXAI.py",
        "--input_list", row['test_path'],
        "--index", str(row['position']),
        "--predictor", "mhcflurry",
        "--xai", row['mode'],
        "--mode", "affinity" if row['tool'] == "mhcflurry_ba" else "presentation_score",
        "--trainf_path", row['train_path'],
        "--dest", '/data1/wuguojia/data/mhc_benchmark/attentionbase/result/'
    ]
    # Run the command. The outcome is kept in its own name so the `result`
    # DataFrame is not overwritten and the cell can be re-run.
    completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    # Report the outcome; a failing run is shown together with its stderr
    # instead of being reported as done.
    if completed.returncode != 0:
        print(f"{row['mode']} {row['tool']} FAILED for index: {index} (exit code {completed.returncode})")
        print(completed.stderr)
    else:
        print(f"{row['mode']} {row['tool']} done for index:", index)

SHAP binding affinity done for index: 0
