### Script of DeepAnno16, including: 
1. dataset construction
2. training procedures
3. comparison with V-xtractor

In [1]:
import os
import sys
import json
import pandas as pd
import numpy as np
from scipy.stats import wilcoxon, entropy, ranksums, norm, mannwhitneyu

from sklearn.metrics import confusion_matrix
import itertools
from sklearn import metrics
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pandas as pd
from tqdm import tqdm
from Bio.Align.Applications import MuscleCommandline
from collections import Counter
from scipy.ndimage.filters import gaussian_filter1d

# Root directories used throughout the notebook: raw SILVA inputs and the
# ten-fold cross-validation workspace (both relative to the working directory).
ori_data_dir = 'data/SILVA_data'
crossVal_data_dir = 'data/Ten_CrossValidation'
# Render every matplotlib figure below with the Arial font family.
plt.rcParams["font.family"] = "Arial"

  from scipy.ndimage.filters import gaussian_filter1d


In [None]:
# Delete matplotlib's cache directory.
# NOTE(review): presumably done so the font cache is rebuilt and the Arial
# setting above takes effect — confirm; this is destructive and only needs to
# be run once.
import shutil
import matplotlib
shutil.rmtree(matplotlib.get_cachedir())

1. The dataset is divided in an 8:1:1 ratio, and each of the ten folds is used in turn as the test set for cross-validation.

In [68]:
## SILVA dataset preprocessing
# Load the SILVA metadata table and the sequence table from the raw-data directory.
SILVA_metainfo = pd.read_csv(os.path.join(ori_data_dir, 'silva_metainfo.csv'),)
SILVA_seqs = pd.read_csv(os.path.join(ori_data_dir, 'silva_seqs_id_only.csv'),)
print(f'SILVA dataset contains {SILVA_metainfo.shape, SILVA_seqs.shape} seqs.')

# Work on a copy with lower-cased Genus/Species so the substring tests below
# are case-insensitive; the original casing is kept in SILVA_metainfo.
SILVA_metainfo_clear = SILVA_metainfo.copy()
SILVA_metainfo_clear['Genus'] = SILVA_metainfo_clear['Genus'].apply(lambda x: x.lower())
SILVA_metainfo_clear['Species'] = SILVA_metainfo_clear['Species'].apply(lambda x: x.lower())
# remove plasmid and phage seqs
SILVA_metainfo_clear = SILVA_metainfo_clear[SILVA_metainfo_clear['Species'].apply(lambda x: ('plasmid' not in x) and ('phage' not in x))]
SILVA_metainfo_clear = SILVA_metainfo_clear[SILVA_metainfo_clear['Genus'].apply(lambda x: ('plasmid' not in x) and ('phage' not in x))]
SILVA_metainfo_clear.reset_index(inplace=True, drop=True)

# Keep only the surviving silva_id rows in the original-cased table and save it.
# Note: this rebinds SILVA_metainfo to the filtered table, which later cells rely on.
SILVA_metainfo = SILVA_metainfo[SILVA_metainfo['silva_id'].isin(list(SILVA_metainfo_clear['silva_id']))]
SILVA_metainfo.reset_index(inplace=True, drop=True)
SILVA_metainfo.to_csv(os.path.join(ori_data_dir, 'silva_metainfo_clear.csv'), index=False)
print(f'SILVA dataset contains {SILVA_metainfo.shape} bacterial seqs.')

# Display the number of retained sequences per genus.
SILVA_metainfo['Genus'].value_counts()

In [None]:
## Get the golden standard of V-regions
from HVR_info_split_clear import *
def str_list(str = '"[1204, 1224]"'):
    """Parse a bracketed, comma-separated region string into a list of ints.

    The surrounding double quotes and square brackets are stripped, the
    remainder is split on ', ', and every piece is converted with ``int``.
    Example: '"[1204, 1224]"' -> [1204, 1224].
    """
    inner = str.strip('"[').strip(']"')
    return list(map(int, inner.split(', ')))

def str_lens(x):
    """Return end minus start for a region string parsed by ``str_list``."""
    bounds = str_list(x)
    start, end = bounds[0], bounds[1]
    return end - start


# Split the merged HVR annotation file with the imported helper.
# NOTE(review): HVR_info_split_main presumably writes HVRs_info_merged_complete.csv
# and HVRs_info_merged_incomplete.csv, which are read below — confirm in
# HVR_info_split_clear.
HVR_info_merged = ori_data_dir + '/HVRs_info' + '_merged.csv'
HVR_info_split_main(HVR_info_merged)

# further remove those aligned in a wrong way
ori_df = os.path.join(ori_data_dir, 'HVRs_info_merged_complete.csv')
ori_df = pd.read_csv(ori_df)
# Keep only the nine region columns plus 'lens' and 'id' (in that order).
ori_df = pd.DataFrame(ori_df, columns=([f'v{i+1}' for i in range(9)] + ['lens', 'id']))
print(ori_df.shape)

# Iterate over v1..v9 (every column except the trailing 'lens' and 'id').
for col in ori_df.columns[: -2]:
    # seqs of each V-region should meet the length constraint
    # The region string is replaced in place by its length (end - start).
    ori_df[col] = ori_df[col].apply(str_lens)
    t_start = ori_df.shape[0]

    if col == 'v9':
        # v9 is exempt from filtering: rows with a non-positive v9 length are
        # only collected in key_v9 (not used further in this cell) and kept.
        key_v9 = ori_df[ori_df['v9'] <= 0]
        continue
    # Drop rows whose region length is not strictly positive.
    ori_df = ori_df[ori_df[col] > 0]
    # v2..v8 must additionally be shorter than 400; v1 has no upper bound
    # (the v9 test is redundant here because v9 already hit `continue`).
    if col != 'v1' and col != 'v9':
        ori_df = ori_df[ori_df[col] < 400]
    ori_df.reset_index(inplace=True, drop=True)
    print(f'Loss because of {col} is {t_start - ori_df.shape[0]}')

# The nine region lengths together must be shorter than the 'lens' column.
t_start = ori_df.shape[0]
ori_df = ori_df[(ori_df['v1'] + ori_df['v2'] + ori_df['v3'] + ori_df['v4'] + ori_df['v5'] + ori_df['v6'] + ori_df['v7'] + ori_df['v8'] + ori_df['v9'])  < ori_df['lens']]
ori_df.reset_index(inplace=True, drop=True)
print(f'Loss because of sum is {t_start - ori_df.shape[0]}')
# clear_df holds the ids that passed every length check above.
clear_df = ori_df.copy()

# Re-read the untouched file so the saved table keeps the original region
# strings and all columns; only the row selection comes from the checks above.
ori_df = os.path.join(ori_data_dir, 'HVRs_info_merged_complete.csv')
ori_df = pd.read_csv(ori_df)

# remove seqs of plasmid and phage
# (SILVA_metainfo was already filtered in the preprocessing cell; ids are
# matched against its 'silva_id_wrong' column.)
ori_df = ori_df[ori_df['id'].isin(list(clear_df['id'])) & ori_df['id'].isin(list(SILVA_metainfo['silva_id_wrong']))]
print(f'{ori_df.shape} seqs with complete 9 V-regions annotation.')
ori_df.to_csv(os.path.join(ori_data_dir, 'HVRs_info_merged_complete_clear.csv'), index=False)

# replace incomplete seqs
# Everything from the complete + incomplete files that did not make it into
# the cleaned complete set (and is in SILVA_metainfo) is saved as "incomplete".
ori_com_df = pd.read_csv(os.path.join(ori_data_dir, 'HVRs_info_merged_complete.csv'))
ori_incom_df = pd.read_csv(os.path.join(ori_data_dir, 'HVRs_info_merged_incomplete.csv'))
ori_com_incom_df = pd.concat([ori_com_df, ori_incom_df])
ori_com_incom_df.reset_index(inplace=True, drop=True)

ori_incom_df = ori_com_incom_df[(~ori_com_incom_df['id'].isin(list(ori_df['id']))) & ori_com_incom_df['id'].isin(list(SILVA_metainfo['silva_id_wrong']))]
ori_incom_df.reset_index(inplace=True, drop=True)
ori_incom_df.to_csv(os.path.join(ori_data_dir, 'HVRs_info_merged_incomplete_clear.csv'), index=False)
print(f'{ori_incom_df.shape} seqs with incomplete 9 V-regions annotation.')

# remove temporary files
os.remove(os.path.join(ori_data_dir, 'HVRs_info_merged_complete.csv'))
os.remove(os.path.join(ori_data_dir, 'HVRs_info_merged_incomplete.csv'))

(239175, 11)
Loss because of v1 is 0
Loss because of v2 is 8
Loss because of v3 is 8
Loss because of v4 is 5
Loss because of v5 is 7
Loss because of v6 is 6
Loss because of v7 is 24
Loss because of v8 is 57
Loss because of sum is 0
(238890, 12) seqs with complete 9 V-regions annotation.
(192498, 12) seqs with incomplete 9 V-regions annotation.


In [None]:
## 10-fold cross validation dataset
# Two-column lookup table mapping 'silva_id_wrong' to 'silva_id'.
SILVA_metainfo_idRefine = pd.DataFrame(SILVA_metainfo, columns=['silva_id_wrong','silva_id'])
clear_ori_with_9hvrs = os.path.join(ori_data_dir, 'HVRs_info_merged_complete_clear.csv')
clear_ori_with_9hvrs = pd.read_csv(clear_ori_with_9hvrs)

# refine the silva id
# Left-join on the old id, overwrite 'id' with the matching 'silva_id', then
# drop the two helper columns again.
clear_ori_with_9hvrs_clearPhagePlasmid = clear_ori_with_9hvrs.copy()
clear_ori_with_9hvrs_clearPhagePlasmid = pd.merge(clear_ori_with_9hvrs_clearPhagePlasmid, SILVA_metainfo_idRefine, how='left', left_on='id', right_on='silva_id_wrong')
clear_ori_with_9hvrs_clearPhagePlasmid['id'] = clear_ori_with_9hvrs_clearPhagePlasmid['silva_id']
clear_ori_with_9hvrs_clearPhagePlasmid.drop(["silva_id", "silva_id_wrong"], axis=1, inplace=True)
clear_ori_with_9hvrs_clearPhagePlasmid.to_csv(os.path.join(crossVal_data_dir, 'HVRs_info_merged_complete_clearBac.csv'), index=False)

print(f'All the training and testing set have {clear_ori_with_9hvrs_clearPhagePlasmid.shape} samples.')

# split training and testing set
# Each fold's test set holds int(10% of the rows); the last fold also takes
# the remainder.
# NOTE(review): only train/test files are written here (roughly 9:1); where a
# validation split is taken, if any, is not visible in this cell.
train_test_ratio = 0.1
test_n = int(clear_ori_with_9hvrs_clearPhagePlasmid.shape[0] * train_test_ratio)
data_n = clear_ori_with_9hvrs_clearPhagePlasmid.shape[0]

# Shuffle once with a fixed seed so the folds are reproducible.
df = clear_ori_with_9hvrs_clearPhagePlasmid.sample(frac=1, random_state=9931).reset_index(drop=True)
for i in range(10):
    start_index = i * test_n
    end_index = (i + 1) * test_n if i < 9 else data_n

    # Rows [start_index, end_index) are the test set; everything else is training data.
    test_df = df.iloc[start_index:end_index]
    train_df = pd.concat([df.iloc[0:start_index], df.iloc[end_index:data_n]])

    train_df.to_csv(os.path.join(crossVal_data_dir, f"trainingSet_{i}Fold.csv"), index=False)
    test_df.to_csv(os.path.join(crossVal_data_dir, f"testingSet_{i}Fold.csv"), index=False)

All the training and testing set have (238890, 12) samples.


2. Training procedure

In [None]:
## Ten-fold cross validation
# config file
# Note: `ori_json` first holds the template path and is then rebound to the
# parsed JSON content, which the code below uses as the configuration template.
ori_json = '16sDeepSeg/config_train.json'
with open(ori_json, "r") as file:
    ori_json = json.load(file)
    
def create_train_config(ori_json, data_dir = '', save_dir = ''):
    """Build a per-fold training configuration from a template.

    Args:
        ori_json: template configuration dict; must contain
            ``['data_loader']['args']`` and ``['trainer']`` sub-dicts.
        data_dir: value stored in ``['data_loader']['args']['data_dir']``.
        save_dir: value stored in ``['trainer']['save_dir']``.

    Returns:
        A new, independent configuration dict. The template is left untouched.
    """
    # Local import so this block does not depend on the notebook's import cell.
    import copy

    # A deep copy is required: dict.copy() is shallow, so assigning into the
    # nested 'data_loader'/'trainer' dicts would also modify the template.
    json_config = copy.deepcopy(ori_json)
    json_config['data_loader']['args']['data_dir'] = data_dir
    json_config['trainer']['save_dir'] = save_dir
    # Point the config at the logger configuration file.
    json_config['log_config'] = '16sDeepSeg/logger/logger_config.json'
    return json_config

# Write one training command per fold into run_ten_Fold.sh, together with a
# per-fold JSON config under <crossVal_data_dir>/Fold_<i>/.
with open( f'run_ten_Fold.sh', 'w') as f:
    for fold_i in range(10):
        fold_dir = os.path.join(crossVal_data_dir, f'Fold_{fold_i}')
        os.makedirs(fold_dir, exist_ok=True)
        train_data = os.path.join(crossVal_data_dir, f"trainingSet_{fold_i}Fold.csv")
        save_path = os.path.join(fold_dir, f"trainingSet_{fold_i}Fold")
        config_i = create_train_config(ori_json, data_dir = train_data, save_dir = save_path)

        # Save the corresponding JSON config file
        out_json_path = os.path.join(fold_dir, f"config_train_{fold_i}Fold.json")
        with open(out_json_path, "w") as file:
            json.dump(config_i, file, indent=4)

        # Build the command that runs the model training
        # NOTE(review): machine-specific absolute path; adjust for other environments.
        model_train_py = '/data1/hyzhang/Projects/EcoPrimer_git/DeepEcoPrimer/16sDeepSeg/train.py' # use an absolute path
        run_cmd = f'python {model_train_py} --config {out_json_path}'

        f.write(f'{run_cmd}\n') 

In [None]:
## 用于画图的函数
def plot_auc_pr(y_true, y_pred, title="", fig_name = './AUC.jpg'):
    """Plot one-vs-rest ROC and precision-recall curves for the ten classes.

    Class 0 is labelled 'Conserved' and classes 1-9 are labelled 'V1'..'V9'.

    Args:
        y_true: array of integer class labels; compared element-wise with
            each class index.
        y_pred: 2-D array whose column ``i`` holds the score of class ``i``.
        title: unused; kept for backward compatibility.
        fig_name: output path of the ROC figure. The PR figure path is
            obtained by replacing 'ROC' with 'PR'; if ``fig_name`` does not
            contain 'ROC', '_PR' is inserted before the extension instead so
            the PR figure never overwrites the ROC figure.

    Returns:
        List with the ROC-AUC of each of the ten classes.
    """
    n_classes = 10
    lw = 2
    fontsize = 14
    label_names = ['Conserved'] + [f'V{i+1}' for i in range(9)]

    # --- ROC curves -------------------------------------------------------
    auc_scores = []
    plt.figure()
    for i in range(n_classes):
        # One-vs-rest: positives are the positions labelled with class i.
        y_label = np.array(y_true == i, dtype=np.int8)
        y_pred_ = np.array(y_pred[:, i])
        auc_scores.append(metrics.roc_auc_score(y_label, y_pred_))

        fpr, tpr, _ = metrics.roc_curve(y_label, y_pred_)
        plt.plot(fpr, tpr,
            lw=lw, label= label_names[i] + ' (AUC = %0.3f)' % metrics.auc(fpr, tpr)) # 16sDeepSeg Position-wise
    # Diagonal = chance level.
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize = fontsize)
    plt.ylabel('True Positive Rate', fontsize = fontsize)
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(fig_name , dpi = 600)

    # --- Precision-recall curves -------------------------------------------
    pr_fig_name = fig_name.replace('ROC', 'PR')
    if pr_fig_name == fig_name:
        # No 'ROC' in the name: derive a distinct file name instead of
        # silently overwriting the ROC figure saved above.
        root, ext = os.path.splitext(fig_name)
        pr_fig_name = f'{root}_PR{ext}'

    plt.figure()
    plt.plot([0,1], [1, 0], linestyle='--')
    for i in range(n_classes):
        y_label = np.array(y_true == i, dtype=np.int8)
        y_pred_ = np.array(y_pred[:, i])
        lr_precision, lr_recall, _ = metrics.precision_recall_curve(y_label, y_pred_)
        plt.plot(lr_recall, lr_precision, lw = 2, label= label_names[i] + ' (AP = %0.3f)' % metrics.average_precision_score(y_label, y_pred_))
    plt.xlabel('Recall', fontsize = fontsize)
    plt.ylabel('Precision', fontsize = fontsize)
    plt.title('Precision Recall Curve')
    plt.legend()
    plt.tight_layout()
    plt.savefig(pr_fig_name , dpi = 600)

    return(auc_scores)

def plot_recall_precision(
    cm,
    target_names,
    fig_path="",
    auc_scores = [],
):
    """Draw a grouped bar chart of ROC-AUC, recall, precision and F1 per class.

    Args:
        cm: confusion matrix indexed as ``cm[true, predicted]`` covering
            at least the ten classes 0..9.
        target_names: tick labels, one per class.
        fig_path: path the figure is saved to.
        auc_scores: per-class ROC-AUC values plotted next to the other metrics.
    """
    class_ids = range(10)
    # Recall: diagonal over row sum; precision: diagonal over column sum.
    recalls = [cm[k, k] / cm[k, :].sum() for k in class_ids]
    precisions = [cm[k, k] / cm[:, k].sum() for k in class_ids]
    # F1 score
    f1_scores = [2 * (r * p) / (r + p) for r, p in zip(recalls, precisions)]

    plt.figure(figsize=(4, 4)) # half of an A4 page
    plt.rcParams['axes.labelsize'] = 12 # in points
    barWidth=0.19
    series = [
        ('ROC_auc', auc_scores),
        ('Recall', recalls),
        ('Precision', precisions),
        ('F1_score', f1_scores),
    ]
    # One bar group per class, one bar per metric inside each group.
    for offset, (series_name, series_values) in enumerate(series):
        bar_positions = [x+offset*barWidth for x in class_ids]
        plt.bar(bar_positions, height=series_values, label=series_name, width=barWidth,)

    # Grid and tick labels (ticks are centred under each bar group).
    plt.grid(axis='y', linestyle='--')
    plt.xticks([x+3*barWidth/2 for x in class_ids], target_names, rotation=315)
    # Show the y axis as percentages with four decimals.
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{x*100:.4f}%'))
    plt.ylim([0.999, 1.0001])
    # Legend outside the axes, on the right.
    plt.legend(bbox_to_anchor=(1.001, 1), loc='upper left', borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(fig_path, dpi = 600)

def plot_conf(
    cm,
    target_names,
    cmap="Blues",
    normalize=True,
    fig_path="",
    auc_scores = [],
):
    """Plot a confusion matrix and the companion recall/precision bar chart.

    Args:
        cm: confusion matrix indexed as ``cm[true, predicted]``.
        target_names: tick labels for both axes, or None for no labels.
        cmap: matplotlib colormap (name or object); None falls back to "Blues".
        normalize: when True, every row is divided by its sum and cells are
            annotated as percentages; otherwise raw counts are shown.
        fig_path: path of the confusion-matrix figure (saved as SVG). The bar
            chart path is derived by replacing 'confusion_matrix' with
            'recall_precision'; if that substring is absent,
            '_recall_precision' is inserted before the extension so the two
            figures never share one file.
        auc_scores: per-class ROC-AUC values forwarded to the bar chart.
    """
    # Derive a distinct path for the recall/precision figure.
    recall_fig_path = fig_path.replace('confusion_matrix', 'recall_precision')
    if fig_path and recall_fig_path == fig_path:
        root, ext = os.path.splitext(fig_path)
        recall_fig_path = f'{root}_recall_precision{ext}'
    # The bar chart is computed from the raw (un-normalised) counts.
    plot_recall_precision(cm, target_names, recall_fig_path, auc_scores = auc_scores)

    if cmap is None:
        cmap = plt.get_cmap("Blues")

    if normalize:
        # Row-normalise so each true class sums to 1.
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
    plt.figure(figsize=(4, 4))
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=315)
        plt.yticks(tick_marks, target_names)

    # Cells above the threshold get white text so they stay readable on dark cells.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_format = "{:0.2f}%" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_value = cm[i, j] * 100 if normalize else cm[i, j]
        plt.text(
            j,
            i,
            cell_format.format(cell_value),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
        )

    plt.tight_layout()
    plt.ylabel("True label", size=15)
    plt.xlabel("Predicted label", size=15)
    plt.savefig(fig_path, format="svg", bbox_inches="tight", dpi=600)
    plt.show()


In [None]:
## Evaluate the base-level performance of the 16sDeepSeg model
# Directory where the figures of this section are saved.
fig_dir = os.path.join(crossVal_data_dir, 'Saved_figs')

# First produce the Vxtractor results
# !python run_Vxtractor.py

# Then run every 16sDeepSeg model
# Note: `ori_json_test` first holds the template path and is then rebound to
# the parsed JSON content.
ori_json_test = '16sDeepSeg/config_test.json'
with open(ori_json_test, "r") as file:
    ori_json_test = json.load(file)
    
def create_test_config(ori_json, data_dir = '', save_dir = ''):
    """Build a per-fold testing configuration from a template.

    Args:
        ori_json: template configuration dict; must contain
            ``['data_loader']['args']`` and ``['trainer']`` sub-dicts.
        data_dir: value stored in ``['data_loader']['args']['data_dir']``.
        save_dir: value stored in ``['trainer']['save_dir']``.

    Returns:
        A new, independent configuration dict. The template is left untouched.
    """
    # Local import so this block does not depend on the notebook's import cell.
    import copy

    # A deep copy is required: dict.copy() is shallow, so assigning into the
    # nested 'data_loader'/'trainer' dicts would also modify the template.
    json_config = copy.deepcopy(ori_json)
    json_config['data_loader']['args']['data_dir'] = data_dir
    json_config['trainer']['save_dir'] = save_dir
    # Point the config at the logger configuration file.
    json_config['log_config'] = '16sDeepSeg/logger/logger_config.json'
    return json_config

def find_best_model(save_dir = 'data/Ten_CrossValidation/Fold_6/trainingSet_6Fold/models/16sRNA_seg_Unet/'):
    """Return the path of ``model_best.pth`` inside a run directory of ``save_dir``.

    Run directories are examined in sorted order so the result is
    deterministic (``os.listdir`` order is arbitrary). Directories that
    actually contain ``model_best.pth`` are preferred; among those the last
    one in sorted order is used. If no directory contains the checkpoint, the
    path inside the first sorted entry is returned, as before, and may not exist.

    Raises:
        OSError: if ``save_dir`` cannot be listed.
        IndexError: if ``save_dir`` is empty.
    """
    run_dirs = sorted(os.listdir(save_dir))
    if not run_dirs:
        raise IndexError(f'no run directory found in {save_dir}')

    with_checkpoint = [
        run_dir for run_dir in run_dirs
        if os.path.isfile(os.path.join(save_dir, run_dir, 'model_best.pth'))
    ]
    # NOTE(review): "last in sorted order" is the most recent run only if the
    # run directory names sort chronologically — confirm the naming scheme.
    chosen = with_checkpoint[-1] if with_checkpoint else run_dirs[0]
    return os.path.join(save_dir, chosen, 'model_best.pth')

# Write one inference command per fold into run_ten_Fold_test.sh, together
# with a per-fold JSON test config under <crossVal_data_dir>/Fold_<i>/.
with open( f'run_ten_Fold_test.sh', 'w') as f:
    for fold_i in range(10):
        fold_dir = os.path.join(crossVal_data_dir, f'Fold_{fold_i}')
        os.makedirs(fold_dir, exist_ok=True)
        # Despite its name, `train_data` holds the fold's *testing* set here.
        train_data = os.path.join(crossVal_data_dir, f"testingSet_{fold_i}Fold.csv")
        save_path = os.path.join(fold_dir, f"testingSet_{fold_i}Fold")
        config_i = create_test_config(ori_json_test, data_dir = train_data, save_dir = save_path)

        # Save the corresponding JSON config file
        out_json_path = os.path.join(fold_dir, f"config_test_{fold_i}Fold.json")
        with open(out_json_path, "w") as file:
            json.dump(config_i, file, indent=4)

        # Build the command that runs the model
        # NOTE(review): the original comment asked for an absolute path, but
        # this one is relative, and '16sRedSeg' differs from the '16sDeepSeg'
        # directory used elsewhere in this notebook — confirm.
        model_train_py = '16sRedSeg/module_output.py'
        try:
            model_best_path = find_best_model(save_dir = os.path.join(fold_dir, f"trainingSet_{fold_i}Fold/models/16sRNA_seg_Unet/"))
        except (OSError, IndexError):
            # The fold has no trained model yet (missing or empty directory):
            # keep the placeholder so the missing fold is visible in the script.
            # Anything else (e.g. KeyboardInterrupt, programming errors) is no
            # longer swallowed.
            model_best_path = '----'

        output_csv = os.path.join(fold_dir, f"testing_set_segmentation_16sDeepSeg.csv")
        run_cmd = f'python {model_train_py} -c {out_json_path} -r {model_best_path} --input {train_data} --output {output_csv} --plot_auc ./confusion_matrix.svg'

        f.write(f'{run_cmd}\n')
        
# First, make supplementary fig 2: pool the position-wise labels and
# predictions of every fold and plot the confusion matrix, ROC and PR curves.
fold_dirs = [f'data/Ten_CrossValidation/Fold_{i}' for i in range(10)]
position_wise_true, position_wise_pred = [], []
for fold_dir in fold_dirs:
    try:
        position_wise_true_t = np.load(os.path.join(fold_dir, 'position_wise_true.npy'),)
        position_wise_pred_t = np.load(os.path.join(fold_dir, 'position_wise_pred.npy'),)
    except OSError:
        # Folds whose result files are missing or unreadable are skipped;
        # other errors (e.g. corrupt arrays) are no longer hidden.
        continue
    position_wise_true.append(position_wise_true_t)
    position_wise_pred.append(position_wise_pred_t)
position_wise_true, position_wise_pred = np.concatenate(position_wise_true, axis=0), np.concatenate(position_wise_pred, axis=0)
# Predicted class = arg-max over the per-class score columns.
conf_mat = confusion_matrix(y_true=position_wise_true, y_pred=np.argmax(position_wise_pred, axis=1))
# Make sure the figure directory exists before anything is saved into it.
os.makedirs(fig_dir, exist_ok=True)
conf_fig_path = os.path.join(fig_dir, 'confusion_matrix.svg')

auc_scores = plot_auc_pr(y_pred=position_wise_pred, y_true=position_wise_true, fig_name=conf_fig_path.replace('confusion_matrix', 'ROC'))
plot_conf(conf_mat, target_names = ['Conserved'] + [f'V{i+1}' for i in range(9)], fig_path=conf_fig_path, normalize=True, auc_scores=auc_scores)


In [2]:
## 用于计算IOU和划分精度的函数
# plot_conf(conf_mat, target_names = ['Conserved'] + [f'V{i+1}' for i in range(9)], fig_path=conf_fig_path, normalize=True, auc_scores=auc_scores)
def compute_iou(rec1, rec2):
    """
    Intersection-over-union of two 1-D intervals.
    :param rec1: (y0, y1), i.e. (left, right)
    :param rec2: (y0, y1)
    :return: scalar IoU; 0 when the intervals do not overlap
    """
    # Candidate overlap: from the larger left edge to the smaller right edge.
    overlap_start = max(rec1[0], rec2[0])
    overlap_end = min(rec1[1], rec2[1])
    if overlap_start >= overlap_end:
        return 0

    overlap = overlap_end - overlap_start
    # Union = sum of both lengths minus the part counted twice.
    union = (rec1[1] - rec1[0]) + (rec2[1] - rec2[0]) - overlap
    return (overlap / union) * 1.0

def str_list(str = '"[1204, 1224]"'):
    """Parse a region string from either tool into a list of ints.

    * Strings containing 'notfound' or 'wrongorder' (failed Vxtractor
      predictions) yield [0, 0].
    * Strings without '[' are treated as Vxtractor output: the first
      whitespace-separated token is taken, single quotes are stripped and the
      token is split on '-'.
    * Strings with '[' are treated as 16sDeepSeg output: quotes and brackets
      are stripped and the rest is split on ', '.
    Every piece is converted with ``int(float(...))``.
    """
    if 'notfound' in str or 'wrongorder' in str:
        # wrong pred from vxtractor
        return [0, 0]

    if '[' in str:
        # pred from 16sDeepSeg
        tokens = str.strip('"[').strip(']"').split(', ')
    else:
        # pred from vxtractor
        tokens = str.split()[0].strip("'").split('-')

    return [int(float(token)) for token in tokens]

def cal_iou_list(real, pred, log_file, acc_cut_off = 0.3):
    """Compute per-sequence IoU between labelled and predicted regions.

    Args:
        real: region strings of the golden labels.
        pred: region strings of the predictions, aligned with ``real``.
        log_file: file object the summary lines are printed to
            (None prints to stdout).
        acc_cut_off: a sequence counts as inaccurately annotated when its IoU
            is below this value.

    Returns:
        (res_df, mean_iou, recall): a one-column DataFrame of the IoU scores,
        their mean, and the fraction of sequences with IoU >= ``acc_cut_off``.
        For empty input both numbers are NaN instead of raising
        ZeroDivisionError.
    """
    iou_score = [
        compute_iou(str_list(label_region), str_list(pred_region))
        for label_region, pred_region in zip(real, pred)
    ]
    total_seq = len(real)
    none_seq = sum(1 for score in iou_score if score < acc_cut_off)

    if total_seq == 0:
        # Nothing to evaluate (e.g. the id merge upstream matched no rows).
        mean_iou = float('nan')
        inaccurate_ratio = float('nan')
    else:
        mean_iou = np.mean(iou_score)
        inaccurate_ratio = none_seq / total_seq

    # print the result to log file
    print(f'Average iou: {mean_iou}', file=log_file)
    # share of sequences whose IoU is below the cut-off
    print(f'Inaccurately annotated seqs: {inaccurate_ratio}', file=log_file)
    res_df = pd.DataFrame(iou_score)
    recall_t = 1 - inaccurate_ratio
    return res_df, mean_iou, recall_t
    
def cal_iou_df(real_df, pred_df, primer = 'p2', log_file = None, acc_cut_off = 0.5):
    """Join labels and predictions on 'id' and score one region column.

    ``real_df`` must provide the column ``<primer>_label`` and ``pred_df`` the
    column ``<primer>``. Returns ``(mean_iou, recall)`` as computed by
    ``cal_iou_list``.
    """
    merged = pd.merge(real_df, pred_df, on='id', how='inner').reset_index(drop=True)
    # Region strings to compare, row-aligned by the inner join.
    predicted_regions = list(merged[primer])
    labelled_regions = list(merged[primer + '_label'])

    _, iou_mean, recall_t = cal_iou_list(labelled_regions, predicted_regions, log_file, acc_cut_off = acc_cut_off)
    return iou_mean, recall_t

# the func to do one fold testing
# the func to do one fold testing
def test_one_fold(folds_root = crossVal_data_dir, fold_num = 0, acc_cut_off = 0.5):
    """Evaluate 16sDeepSeg and Vxtractor on one cross-validation fold.

    For each method the prediction CSV and the running time are loaded, the
    sequences with all nine V-regions predicted are counted, and the mean IoU
    and recall of every region are computed against the fold's golden labels.
    A log is written to ``<folds_root>/Fold_<fold_num>/log_<method>.txt``.

    Args:
        folds_root: directory holding ``testingSet_<n>Fold.csv`` and the
            ``Fold_<n>`` sub-directories.
        fold_num: index of the fold to evaluate.
        acc_cut_off: IoU threshold below which an annotation counts as
            inaccurate.

    Returns:
        List with one dict per method (16sDeepSeg first, then Vxtractor)
        holding fold, method, running time, support/efficacy counts and the
        per-region ``v<i>_iou_score`` / ``v<i>_recall`` values.
    """
    fold_dir = os.path.join(folds_root, f'Fold_{fold_num}')
    df_evaluation = []
    # The golden labels are identical for both methods: read them once.
    goldenLabel_df = pd.read_csv(f'{folds_root}/testingSet_{fold_num}Fold.csv')

    for model_checked_nm in ['16sDeepSeg', 'Vxtractor']:
        is_vxtractor = model_checked_nm == 'Vxtractor'
        pred_csv = f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv'

        # `with` guarantees the log is flushed and closed even if a step below raises.
        with open(f'{fold_dir}/log_{model_checked_nm}.txt', 'w') as log_file:
            # check 16sDeepSeg & Vxtractor model
            if is_vxtractor:
                # The Vxtractor CSV has two leading lines to skip; keep the
                # sequence id and the nine region columns, renamed to the
                # lower-case names used by 16sDeepSeg.
                pred_df = pd.read_csv(pred_csv, skiprows=range(2), index_col=False)
                cols_obj = ['Sequence'] + [f'V{i+1}' for i in range(9)]
                pred_df = pd.DataFrame(pred_df, columns=cols_obj)
                pred_df.rename({'Sequence': 'id'}, axis=1, inplace=True)
                pred_df['id'] = pred_df['id'].apply(lambda x: x.strip("'"))
                pred_df.rename({f'V{i+1}': f'v{i+1}' for i in range(9)}, axis=1, inplace=True)
            else:
                pred_df = pd.read_csv(pred_csv)

            # get running time
            if is_vxtractor:
                # NOTE(review): this path is hard-coded and ignores `folds_root`.
                run_time = pd.read_csv('data/Ten_CrossValidation/Vxtractor_10Fold.csv')
                run_time_foldx = run_time.iloc[fold_num, 1]
            else:
                with open(f'{fold_dir}/16sDeepSeg_running_time.txt', 'r') as f:
                    run_time_foldx = float(f.readline().strip())

            # calculate model efficacy: keep only sequences for which all nine
            # regions were predicted (failure markers differ between the tools).
            pred_with_nineVregion_df = pred_df.copy()
            for i in range(9):
                region_col = pred_with_nineVregion_df[f'v{i+1}']
                if is_vxtractor:
                    found = region_col.apply(lambda x: ('notfound' not in x) and ('wrongorder' not in x))
                else:
                    found = region_col.apply(lambda x: ('-1' not in x))
                pred_with_nineVregion_df = pred_with_nineVregion_df[found]
                pred_with_nineVregion_df.reset_index(inplace=True, drop=True)
            number_pred_with_nineVregion = pred_with_nineVregion_df.shape[0]
            print(f'Number of seqs with 9 V-regions: {number_pred_with_nineVregion}', file=log_file)

            # calculate the iou score
            print('Start iou cal.', file=log_file)
            iou_total = 0.0
            # calculate 9 v-regions
            df_t = {'Fold': fold_num, 'Methods': model_checked_nm, 'Running_time': run_time_foldx, 'Support_num': pred_df.shape[0], 'Efficacy_num': number_pred_with_nineVregion}
            for p_i in range(9):
                col = f'v{p_i+1}'
                # Only sequences whose golden label for this region is present.
                with_primer_aligned = goldenLabel_df[~goldenLabel_df[col].isin(["[-1, -1]"])]
                print(f'Region {col}: {with_primer_aligned.shape} seqs with golden annotations.', file=log_file)
                with_primer_aligned = with_primer_aligned.reset_index(drop=True)
                with_primer_aligned = pd.DataFrame(with_primer_aligned, columns=['id', col])
                with_primer_aligned = with_primer_aligned.rename(columns = {col: col + '_label'})
                # calculate iou score
                iou_mean, recall_t = cal_iou_df(with_primer_aligned, pred_df = pred_df, primer=col, log_file = log_file, acc_cut_off = acc_cut_off)
                iou_total += iou_mean
                df_t[col + '_iou_score'] = iou_mean
                df_t[col + '_recall'] = recall_t

            # calculate the mean iou score
            iou_total /= 9
            print(f'Total mean iou score is {iou_total}', file=log_file)

        df_evaluation.append(df_t)
    return df_evaluation


In [None]:
## Compare the segmentation performance of 16sDeepSeg and V-xtractor

# Produce the results for Table 1: evaluate every fold and stack the
# per-method result rows into one table.
acc_cutoff = 0.5
df_16sDeepSeg_Vxtractor_eva = []
for foid_i in range(10):
    df_16sDeepSeg_Vxtractor_eva_foldx = test_one_fold(
        folds_root=crossVal_data_dir,
        fold_num=foid_i,
        acc_cut_off=acc_cutoff,
    )
    df_16sDeepSeg_Vxtractor_eva += df_16sDeepSeg_Vxtractor_eva_foldx
df_16sDeepSeg_Vxtractor_eva = pd.DataFrame(df_16sDeepSeg_Vxtractor_eva)
df_16sDeepSeg_Vxtractor_eva.to_csv(os.path.join(crossVal_data_dir, f'TenFoldCrossValidation_cutoff_{acc_cutoff}.csv'), index=False)

# Save the per-method mean and standard deviation.
# Efficacy = share of sequences for which all nine V-regions were predicted.
df_16sDeepSeg_Vxtractor_eva['Efficacy'] = df_16sDeepSeg_Vxtractor_eva['Efficacy_num'] / df_16sDeepSeg_Vxtractor_eva['Support_num']
# The context manager closes (and thereby saves) the workbook even if one of
# the to_excel calls raises.
with pd.ExcelWriter(f'data/Ten_CrossValidation/Saved_figs/TenFoldCrossValidation_cutoff_{acc_cutoff}.xlsx', engine="xlsxwriter") as writer:
    df_16sDeepSeg_Vxtractor_eva.groupby('Methods').mean().to_excel(writer, sheet_name="Mean", )
    df_16sDeepSeg_Vxtractor_eva.groupby('Methods').std().to_excel(writer, sheet_name="Std", )

  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
  pred_df = p

In [None]:
# Add p-values: collect the per-fold values of every metric as one list per
# method, then transpose so that rows are metrics and columns are methods.
df_16sDeepSeg_Vxtractor_eva_wilcox = df_16sDeepSeg_Vxtractor_eva.groupby('Methods').agg(lambda x: list(x))
df_16sDeepSeg_Vxtractor_eva_wilcox = df_16sDeepSeg_Vxtractor_eva_wilcox.transpose()

for test_ind in df_16sDeepSeg_Vxtractor_eva_wilcox.index:
    RedSeg_ls, Vxtractor_ls = df_16sDeepSeg_Vxtractor_eva_wilcox.loc[test_ind, '16sDeepSeg'], df_16sDeepSeg_Vxtractor_eva_wilcox.loc[test_ind, 'Vxtractor']
    # Paired Wilcoxon signed-rank test across folds (scipy.stats.wilcoxon;
    # this is not the unpaired rank-sum test).
    try:
        statistic, p_value = wilcoxon(RedSeg_ls, Vxtractor_ls)
    except ValueError:
        # The test is undefined for this row (e.g. both methods have identical
        # values in every fold, or the lists differ in length): store the
        # sentinel -1. Other exceptions are no longer swallowed.
        statistic, p_value = -1, -1
    df_16sDeepSeg_Vxtractor_eva_wilcox.loc[test_ind, 'wilcox_pval'] = p_value

df_16sDeepSeg_Vxtractor_eva_wilcox.to_csv(f'data/Ten_CrossValidation/Saved_figs/TenFoldCrossValidation_wilcoxPval_cutoff_{acc_cutoff}.csv')



3. 重新换其他metrics查看Vxtractor和RedSeg划分结果：
* 先看看保守区长度情况(已经确定较长，试试看vx的结果是不是都覆盖了RedSeg结果)
* 划分的保守区使用MSA看看具体保守性

In [2]:
# NOTE(review): `deletion_cutoff` is not referenced in the visible part of
# this notebook — confirm where it is consumed before changing or removing it.
deletion_cutoff = 0.1

def check_intersection(rec1, rec2):
    """
    Check whether the longer Vxtractor segment contains the shorter RedSeg
    segment: returns the overlap length divided by the length of the shorter
    interval, 0 when the intervals do not overlap, and -1 when either
    interval contains a -1 (missing boundary).
    """
    length_1, length_2 = rec1[1] - rec1[0], rec2[1] - rec2[0]
    overlap = min(rec1[1], rec2[1]) - max(rec1[0], rec2[0])

    # A -1 boundary marks a region that could not be delimited.
    if -1 in rec1 or -1 in rec2:
        return -1
    # Non-positive overlap means the intervals are disjoint or only touch.
    if overlap <= 0:
        return 0
    return (overlap / min(length_1, length_2)) * 1.0
    
def test_one_fold_intersection(folds_root = crossVal_data_dir, fold_num = 0,):
    """Collect the conserved regions implied by both tools for one fold.

    For each method the nine predicted V-regions are parsed into
    ``[start, end]`` lists (``[-1, -1]`` for failed Vxtractor predictions).
    The conserved region ``conserve_<k>`` (k = 2..9) is then taken as
    ``[end of v<k-1>, start of v<k>]``.

    Args:
        folds_root: directory holding the ``Fold_<n>`` sub-directories.
        fold_num: index of the fold to load.

    Returns:
        DataFrame with 'id' and the columns ``conserve_<k>_16sDeepSeg`` /
        ``conserve_<k>_Vxtractor``, inner-joined on 'id'.
    """
    # Local import so this block does not depend on the notebook's import cell.
    import ast

    # 2024-04-25: recompute the metrics of both models
    fold_dir = os.path.join(folds_root, f'Fold_{fold_num}')
    os.makedirs(fold_dir, exist_ok=True)
    fold_dir_conserved = os.path.join(fold_dir, f'Conserved')
    os.makedirs(fold_dir_conserved, exist_ok=True)

    def parse_vxtractor_region(cell):
        # "'<start>-<end> HMM=...'" -> [start, end]; failed predictions -> [-1, -1].
        if ('notfound' in cell) or ('wrongorder' in cell):
            return [-1, -1]
        return [int(part.split(' HMM=')[0]) for part in cell.strip("'").split('-')]

    model_preds = []
    for model_checked_nm in ['16sDeepSeg', 'Vxtractor']:
        # check 16sDeepSeg & Vxtractor model
        pred_csv = f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv'
        if model_checked_nm == 'Vxtractor':
            pred_df = pd.read_csv(pred_csv, skiprows=range(2), index_col=False)
            cols_obj = ['Sequence'] + [f'V{i+1}' for i in range(9)]
            pred_df = pd.DataFrame(pred_df, columns=cols_obj)
            pred_df.rename({'Sequence': 'id'}, axis=1, inplace=True)
            pred_df['id'] = pred_df['id'].apply(lambda x: x.strip("'"))
            for i in range(9):
                pred_df.rename({f'V{i+1}': f'v{i+1}'}, axis=1, inplace=True)
                pred_df[f'v{i+1}'] = pred_df[f'v{i+1}'].apply(parse_vxtractor_region)
        else:
            pred_df = pd.read_csv(pred_csv)
            for i in range(9):
                # literal_eval parses the "[start, end]" strings without
                # executing arbitrary code from the CSV (the original used eval).
                pred_df[f'v{i+1}'] = pred_df[f'v{i+1}'].apply(ast.literal_eval)

        # Derive the conserved regions between consecutive V-regions.
        for conserve_i in range(8):
            pred_df[f'conserve_{conserve_i+2}_{model_checked_nm}'] = pred_df.apply(lambda x: [x[f'v{conserve_i+1}'][-1], x[f'v{conserve_i+2}'][0]] , axis = 1)
        model_preds.append(pred_df.loc[:, ['id'] + [f'conserve_{i}_{model_checked_nm}' for i in range(2, 10)]])

    model_preds = pd.merge(model_preds[0], model_preds[1], on='id', how='inner')

    return model_preds


In [4]:
# At this point it is essentially established that Vxtractor covers our
# regions, so this check can be ignored.
# Loads fold 0 (the function's default) and, for a subset of the conserved
# regions, prints the mean overlap ratio over rows where both tools found the
# region (ratio >= 0; -1 marks a missing boundary).
model_preds = test_one_fold_intersection()
# NOTE(review): conserved regions 7 and 8 are not included in this list — confirm this is intended.
for conserve_i in [2, 3, 4, 5, 6, 9]:
    model_preds[f'conserve_{conserve_i}_intersec'] = model_preds.apply(lambda x: check_intersection(x[f'conserve_{conserve_i}_Vxtractor'], x[f'conserve_{conserve_i}_16sDeepSeg']), axis = 1)
    print(f'Conserve {conserve_i}: ', model_preds.loc[model_preds[f'conserve_{conserve_i}_intersec'] >= 0, f'conserve_{conserve_i}_intersec'].mean())
model_preds

  pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)


Unnamed: 0,id,conserve_2_16sDeepSeg,conserve_3_16sDeepSeg,conserve_4_16sDeepSeg,conserve_5_16sDeepSeg,conserve_6_16sDeepSeg,conserve_7_16sDeepSeg,conserve_8_16sDeepSeg,conserve_9_16sDeepSeg,conserve_2_Vxtractor,conserve_3_Vxtractor,conserve_4_Vxtractor,conserve_5_Vxtractor,conserve_6_Vxtractor,conserve_7_Vxtractor,conserve_8_Vxtractor,conserve_9_Vxtractor
0,CP000302.1607157.1608689,"[95, 111]","[312, 355]","[496, 524]","[674, 801]","[902, 976]","[1037, 1105]","[1169, 1229]","[1379, 1399]","[91, 172]","[233, 408]","[470, 580]","[664, 813]","[854, 986]","[1036, 1106]","[1166, 1234]","[1290, 1401]"
1,JQ918082.1.1528,"[97, 116]","[334, 377]","[519, 547]","[696, 823]","[925, 999]","[1062, 1130]","[1194, 1254]","[1404, 1424]","[96, 178]","[255, 430]","[493, 603]","[686, 835]","[877, 1009]","[1061, 1131]","[1191, 1259]","[1315, 1426]"
2,CP018447.5249193.5250746,"[106, 122]","[323, 366]","[507, 535]","[685, 812]","[913, 987]","[1048, 1117]","[1179, 1239]","[1389, 1409]","[102, 183]","[244, 419]","[481, 591]","[675, 824]","[865, 997]","[1047, 1117]","[1176, 1244]","[1300, 1411]"
3,JN536939.1.1442,"[56, 74]","[279, 322]","[458, 486]","[638, 763]","[861, 935]","[995, 1063]","[1127, 1181]","[1337, 1357]","[55, 137]","[200, 375]","[432, 542]","[626, 775]","[813, 945]","[994, 1064]","[1124, 1192]","[1247, 1359]"
4,EU461376.1.1386,"[104, 123]","[327, 370]","[486, 514]","[664, 791]","[892, 966]","[1030, 1099]","[1155, 1209]","[1368, 1385]","[103, 185]","[248, 423]","[460, 570]","[654, 803]","[844, 976]","[1029, 1099]","[1152, 1220]","[1276, -1]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23884,EU772318.1.1374,"[88, 107]","[312, 355]","[474, 502]","[652, 779]","[880, 954]","[1018, 1086]","[1143, 1197]","[1353, 1373]","[87, 169]","[233, 408]","[448, 558]","[642, 791]","[832, 964]","[1017, 1087]","[1140, 1208]","[1264, -1]"
23885,AY786080.1.1478,"[81, 93]","[288, 323]","[439, 467]","[617, 744]","[844, 918]","[983, 1052]","[1114, 1174]","[1323, 1343]","[73, 154]","[201, 376]","[413, 523]","[607, 756]","[796, 928]","[982, 1052]","[1111, 1179]","[1234, 1345]"
23886,CP000448.417882.419570,"[226, 245]","[464, 507]","[649, 677]","[827, 954]","[1056, 1130]","[1191, 1259]","[1322, 1382]","[1532, 1552]","[225, 307]","[385, 560]","[623, 733]","[817, 966]","[1008, 1140]","[1190, 1260]","[1319, 1387]","[1443, 1554]"
23887,KJ572262.1.1357,"[26, 45]","[249, 286]","[427, 455]","[605, 732]","[826, 904]","[962, 1030]","[1094, 1154]","[1306, 1325]","[-1, 107]","[170, 344]","[401, 511]","[595, 744]","[782, 914]","[961, 1031]","[1091, 1159]","[1215, -1]"


In [2]:
# 20240428: first step - build a table comparing how conserved each model's conserved regions are
def generate_conserve_fa(pred_df, fasta_file_to, conserve_region = 'conserve_2'):
    """Export one conserved region per sequence to FASTA and prepare a MUSCLE command.

    For every row of ``pred_df`` the ``[start, end]`` pair stored in the
    ``conserve_region`` column is cut out of the row's ``'16s_rna'`` string as
    ``sequence[start-1:end]``. Rows where either coordinate is -1, or whose
    extracted fragment is shorter than 10 characters, are left out.

    Parameters
    ----------
    pred_df : DataFrame with columns 'id', '16s_rna' and ``conserve_region``.
    fasta_file_to : path of the FASTA file to write.
    conserve_region : name of the column holding the ``[start, end]`` pair.

    Returns
    -------
    The ``MuscleCommandline`` object built for ``fasta_file_to``; the command
    itself is not executed here.
    """
    kept_records = []
    for _, entry in pred_df.iterrows():
        record_id = entry['id']
        full_seq = entry['16s_rna']
        region_start, region_end = entry[conserve_region]
        region_seq = full_seq[region_start-1:region_end]
        record = SeqRecord(Seq(region_seq), id=record_id, description='')

        # Rows annotated with -1 are not handled for now; only regions that
        # could actually be located are compared.
        boundary_missing = (region_start == -1) or (region_end == -1)
        if boundary_missing or len(region_seq) < 10:
            continue  # unlocated, or too short to be usable
        kept_records.append(record)

    SeqIO.write(kept_records, fasta_file_to, "fasta")

    # Describe the multiple-alignment run that should follow
    aligned_path = fasta_file_to.replace(".fa", "_afterMuscle.fa")
    return MuscleCommandline(input=fasta_file_to, out=aligned_path)
    
def test_one_fold_v2(folds_root = crossVal_data_dir, fold_num = 0,):
    """Extract the conserved regions of one fold for both models and align them.

    Steps, as performed below:
      1. Create ``<folds_root>/Fold_<fold_num>/Conserved`` if needed.
      2. Read the first three columns of ``testingSet_<fold_num>Fold.csv``.
      3. For '16sDeepSeg' and 'Vxtractor', load the predicted v1..v9 boundaries
         (``[-1, -1]`` is used for V-xtractor cells containing 'notfound' or
         'wrongorder').
      4. For each of the 8 gaps between consecutive V regions, define
         ``conserve_<k>`` as ``[end of v<k-1>, start of v<k>]``, write it to
         FASTA via ``generate_conserve_fa`` and append the alignment command to
         ``run_muscle.bash`` as a background job.
      5. Run the generated script with ``bash`` through ``os.system``.

    Returns None; all results are files written under the Conserved directory.
    """
    # 20240425: recompute the metrics of both models
    fold_dir = os.path.join(folds_root, f'Fold_{fold_num}')
    os.makedirs(fold_dir, exist_ok=True)
    fold_dir_conserved = os.path.join(fold_dir, f'Conserved')
    os.makedirs(fold_dir_conserved, exist_ok=True)
    goldenLabel_df = pd.read_csv(f'{folds_root}/testingSet_{fold_num}Fold.csv').iloc[:, :3] # only the sequences provided by this file are needed
    
    with open(f"{fold_dir_conserved}/run_muscle.bash", "w") as f:
        for model_checked_nm in ['16sDeepSeg', 'Vxtractor']:
            # check 16sDeepSeg & Vxtractor model
            if model_checked_nm == 'Vxtractor':
                # The first two rows of the V-xtractor CSV are skipped; only 'Sequence' and V1..V9 are kept.
                pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv', skiprows=range(2), index_col=False)
                cols_obj = ['Sequence'] + [f'V{i+1}' for i in range(9)]
                pred_df = pd.DataFrame(pred_df, columns=cols_obj)
                pred_df.rename({'Sequence': 'id'}, axis=1, inplace=True)
                pred_df['id'] = pred_df['id'].apply(lambda x: x.strip("'"))
                for i in range(9):
                    pred_df.rename({f'V{i+1}': f'v{i+1}'}, axis=1, inplace=True)
                    # Each cell is split on '-' and anything after ' HMM=' is dropped before int();
                    # cells containing 'notfound' or 'wrongorder' become [-1, -1].
                    pred_df[f'v{i+1}'] = pred_df[f'v{i+1}'].apply(lambda x: [int(i.split(' HMM=')[0]) for i in x.strip("'").split('-')] if (('notfound' not in x) and ('wrongorder' not in x)) else [-1, -1])
            else:
                pred_df = pd.read_csv(f'{folds_root}/Fold_{fold_num}/testing_set_segmentation_{model_checked_nm}.csv')
                for i in range(9):
                    # NOTE(review): eval() executes whatever text is stored in the CSV cell;
                    # ast.literal_eval would be safer if the cells are plain list literals — confirm.
                    pred_df[f'v{i+1}'] = pred_df[f'v{i+1}'].apply(lambda x: eval(x))
            
            # Derive the conserved-region coordinates
            pred_df = pd.merge(pred_df, goldenLabel_df, on = 'id', how='left', )
            for conserve_i in range(8):
                # conserve_<k> spans from the last coordinate of v<k-1> to the first coordinate of v<k>
                pred_df[f'conserve_{conserve_i+2}'] = pred_df.apply(lambda x: [x[f'v{conserve_i+1}'][-1], x[f'v{conserve_i+2}'][0]] , axis = 1)
                cline = generate_conserve_fa(pred_df, f'{fold_dir_conserved}/{model_checked_nm}_conserve_{conserve_i+2}.fa', conserve_region = f'conserve_{conserve_i+2}')
                
                # Rewrite the flags of the generated command: '-in' ends up as '-super5',
                # '-out' as '-output', and ' -threads 32' is appended
                # (presumably to match a newer MUSCLE command line — confirm).
                # NOTE(review): these are plain substring replacements over the whole command,
                # so a path containing '-in' or '-out' would be altered as well.
                cline = str(cline).replace("-in", "-align").replace("-out", "-output")
                cline = cline.replace("-align", "-super5") + " -threads 32"
                f.write(str(cline) + ' &\n')
        # 'wait' makes the script block until every background alignment has finished
        f.write('wait')
    
    os.system("bash " + f"{fold_dir_conserved}/run_muscle.bash")
    return

# Helpers for computing conservation results once the MSA has been run.
# Maps each nucleotide code (plain or degenerate) to the list of plain bases it
# can stand for; parse_MSA substitutes every code by the FIRST base in its list.
degenerate_base_table = {
    "A": ["A"],
    "T": ["T"],
    "G": ["G"],
    "C": ["C"],
    "Y": ["C", "T"],
    "R": ["A", "G"],
    "W": ["A", "T"],
    "S": ["G", "C"],
    "K": ["T", "G"],
    "M": ["C", "A"],
    "N": ["A", "T", "G", "C"],
    "H": ["A", "C", "T"],
    "V": ["A", "C", "G"],
    "B": ["C", "G", "T"],
    "D": ["A", "G", "T"],
}

def parse_MSA(
    mas_fna, 
    deletion_cutoff=0.9,  # columns whose '-' frequency reaches this value are treated as MSA artefacts and dropped
    ):
    """Summarise an aligned FASTA as per-column base frequencies and mean entropy.

    If a ``*parsedMSA.csv`` file sits next to the alignment (same path with
    ``afterMuscle.fa`` replaced by ``parsedMSA.csv``) it is loaded instead of
    recomputing the frequency table; this function never writes that file.

    Parameters
    ----------
    mas_fna : str
        Path of the aligned FASTA file.
    deletion_cutoff : float
        Only columns with a '-' frequency strictly below this value are kept.

    Returns
    -------
    tuple
        ``(pos_conserved, entro_mean, conser_lens, total_seq_num)`` where
        ``pos_conserved`` is the filtered frequency table (one row per kept
        alignment column), ``entro_mean`` the mean natural-log entropy of its
        rows (NaN if no column is kept), ``conser_lens`` the number of kept
        columns and ``total_seq_num`` the number of records in ``mas_fna``.
        When ``mas_fna`` does not exist, or holds no records and no cached
        table is available, ``(None, 'Error', 'Error', 'Error')`` is returned
        so that callers can always unpack four values.
    """
    if not os.path.exists(mas_fna):
        return None, 'Error', 'Error', 'Error'
    parsed_file = mas_fna.replace('afterMuscle.fa', 'parsedMSA.csv')
    # SeqIO.parse yields records lazily and can be consumed only once, so the
    # records are materialised before being both counted and iterated.
    recs = list(SeqIO.parse(mas_fna, "fasta"))
    total_seq_num = len(recs)  # number of sequences taking part in the alignment
    if not os.path.exists(parsed_file):
        if total_seq_num == 0:
            # An empty alignment has no columns to summarise.
            return None, 'Error', 'Error', 'Error'
        seqs_np = []
        for rec in recs:
            seq_now = str(rec.seq)
            # Replace every degenerate code by a plain base. Degenerate bases are
            # rare in these reference sequences, so the first candidate base is used.
            for dege_base, subs_nor_base in degenerate_base_table.items():
                seq_now = seq_now.replace(dege_base, subs_nor_base[0])
            seqs_np.append(list(seq_now))
        seqs_np = np.array(seqs_np)

        # Count the symbols seen at every alignment column. The leading sentinel
        # row fixes the column order (A, T, G, C, -) and guarantees that a '-'
        # column exists; it is removed again below.
        pos_conserved = [{"A": -1, "T": -1, "G": -1, "C": -1, "-": -1}]
        for i in range(seqs_np.shape[1]):
            pos_conserved.append(dict(Counter(seqs_np[:, i].tolist())))
        pos_conserved = pd.DataFrame(pos_conserved)
        pos_conserved.fillna(0.0, inplace=True)
        pos_conserved = pos_conserved.iloc[1:, :]  # drop the sentinel row
        pos_conserved.reset_index(inplace=True, drop=True)
        pos_conserved = pos_conserved.div(total_seq_num, axis = 1)
    else:
        pos_conserved = pd.read_csv(parsed_file)

    # NOTE(review): per the original author's note, the cached tables were produced
    # with deletion < 0.9, so a different deletion_cutoff must be smaller than 0.9
    # to have a consistent effect on cached data — confirm.
    pos_conserved = pos_conserved.loc[pos_conserved['-'] < deletion_cutoff, :].reset_index(drop=True)  # drop gap-dominated columns

    entro_list = [entropy(pos_i, base = np.e) for pos_i in pos_conserved.values]
    # Avoid the "mean of empty slice" warning when every column was filtered out.
    entro_mean = np.mean(entro_list) if entro_list else float('nan')
    conser_lens = len(entro_list)
    return pos_conserved, entro_mean, conser_lens, total_seq_num

def test_one_fold_entropyRes_v2(folds_root = crossVal_data_dir, fold_num = 0,):
    """Collect the conservation summary of every conserved region for one fold.

    For both models ('16sDeepSeg', 'Vxtractor') and for conserved regions 2..9,
    the aligned FASTA in ``<folds_root>/Fold_<fold_num>/Conserved`` is passed to
    ``parse_MSA`` and its mean entropy, kept column count and sequence count are
    stored. Returns a DataFrame with one row per model.

    Note: the cutoff handed to ``parse_MSA`` is read from the notebook-level
    variable ``deletion_cutoff``, which must be defined before this is called.
    """
    conserved_dir = os.path.join(os.path.join(folds_root, f'Fold_{fold_num}'), f'Conserved')
    per_model_rows = []
    for model_checked_nm in ['16sDeepSeg', 'Vxtractor']:
        summary = {'Fold': fold_num, 'Methods': model_checked_nm}
        for region_no in range(2, 10):
            msa_path = f'{conserved_dir}/{model_checked_nm}_conserve_{region_no}_afterMuscle.fa'
            pos_conserved, entro_mean, conser_lens, total_seq_num = parse_MSA(msa_path, deletion_cutoff = deletion_cutoff)
            # The per-position table was already saved as *_parsedMSA.csv on the
            # first run, so it is deliberately not written again here.
            summary.update({
                f'conserve_{region_no}_entropy': entro_mean,
                f'conserve_{region_no}_lens': conser_lens,
                f'conserve_{region_no}_Nseqs': total_seq_num,
            })
        per_model_rows.append(summary)
    return pd.DataFrame(per_model_rows)

# Summarise every fold and store the result as res_temp.csv inside that fold's Conserved directory.
for fold_i in tqdm(range(0, 10)):
    # test_one_fold_v2(folds_root = crossVal_data_dir, fold_num = fold_i,) # 20240427 -- these results have already been produced, so this step is skipped to save time
    res_fold = test_one_fold_entropyRes_v2(folds_root = crossVal_data_dir, fold_num = fold_i,)
    res_fold.to_csv(f'{crossVal_data_dir}/Fold_{fold_i}/Conserved/res_temp.csv', index=False)

100%|██████████| 10/10 [03:00<00:00, 18.09s/it]


In [6]:
# Stack the per-fold summaries (res_temp.csv) into one table: two rows per fold, one per method.
conservatism_df = pd.concat(
    [
        pd.read_csv(f'{crossVal_data_dir}/Fold_{fold_i}/Conserved/res_temp.csv',)
        for fold_i in tqdm(range(0, 10))
    ],
    axis = 0,
).reset_index(drop =True)

# Row selectors for the two methods; rows keep fold order, which is what pairs
# the two samples fold-by-fold in the Wilcoxon signed-rank test below.
vx_rows = conservatism_df['Methods'] == 'Vxtractor'
redseg_rows = conservatism_df['Methods'] == '16sDeepSeg'

conservatism_stat_df = []
for conserve_i in range(8):
    entropy_col = f'conserve_{conserve_i+2}_entropy'
    lens_col = f'conserve_{conserve_i+2}_lens'

    vx_cons = list(conservatism_df.loc[vx_rows, entropy_col])
    redseg_cons = list(conservatism_df.loc[redseg_rows, entropy_col])
    vx_lens = list(conservatism_df.loc[vx_rows, lens_col])
    redseg_lens = list(conservatism_df.loc[redseg_rows, lens_col])

    _, pvalue = wilcoxon(vx_cons, redseg_cons)

    # Each cell reads "mean±std of entropy (mean±std of kept alignment columns)".
    conservatism_stat_df.append({
        'conserved_region': f'C{conserve_i + 2}',
        'RedSeg_entropy': f'{np.mean(redseg_cons):.4f}±{np.std(redseg_cons):.4f} ({np.mean(redseg_lens):.2f}±{np.std(redseg_lens):.2f})',
        'Vxtractor_entropy': f'{np.mean(vx_cons):.4f}±{np.std(vx_cons):.4f} ({np.mean(vx_lens):.2f}±{np.std(vx_lens):.2f})',
        'Wilcox_paired_pvalue': pvalue,
    })
conservatism_stat_df = pd.DataFrame(conservatism_stat_df)
conservatism_stat_df.to_csv(f'{crossVal_data_dir}/Saved_figs/Conserved_entropy_stat.csv', index=False, encoding='utf-8-sig')

100%|██████████| 10/10 [00:00<00:00, 127.09it/s]


In [4]:
# 20240427: for conserved regions where RedSeg predicts a shorter span, slide a window of
# that length along the V-xtractor region and see whether RedSeg ranks among the best windows
def calculate_conservatism(vxtractor_file, redseg_file, deletion_cutoff=0.1):
    """Compare the RedSeg region's mean entropy with same-width V-xtractor windows.

    Both inputs are per-position frequency tables (CSV with a '-' column); rows
    whose '-' value is not below ``deletion_cutoff`` are removed first. A window
    as long as the filtered RedSeg table is then slid, one row at a time, over
    the filtered V-xtractor table and the mean per-position entropy (natural
    log) of every window is recorded.

    Parameters
    ----------
    vxtractor_file, redseg_file : str
        Paths of the two per-position CSV tables.
    deletion_cutoff : float
        Rows with a '-' value >= this are discarded from both tables.

    Returns
    -------
    tuple
        ``(conservatisms, [redseg_conservatism], rank, conserved_than_ratio)``:
        the list of window entropies, the RedSeg mean entropy wrapped in a
        one-element list, 1 + the number of windows with a strictly lower
        entropy than RedSeg, and the fraction of windows whose entropy is >=
        RedSeg's. When the V-xtractor table is shorter than the window no
        window exists: the list is empty, ``rank`` is 1 and the ratio is NaN
        (instead of raising ZeroDivisionError).
    """
    # Load both tables and drop gap-dominated positions
    vxtractor_data = pd.read_csv(vxtractor_file,)
    redseg_data = pd.read_csv(redseg_file,)
    vxtractor_data = vxtractor_data[vxtractor_data['-'] < deletion_cutoff].reset_index(drop=True)
    redseg_data = redseg_data[redseg_data['-'] < deletion_cutoff].reset_index(drop=True)

    # The window is as wide as the RedSeg table has rows
    window_width = redseg_data.shape[0]

    # Number of window positions (zero or negative means no window fits)
    num_windows = vxtractor_data.shape[0] - window_width + 1

    # The entropy of a position does not depend on the window it falls in, so it
    # is computed once for the whole table and only averaged per window.
    vx_pos_entropy = entropy(vxtractor_data.values.T, base=np.e)
    conservatisms = [
        vx_pos_entropy[i:i + window_width].mean() for i in range(num_windows)
    ]

    # Mean entropy of the RedSeg region itself
    redseg_conservatism = entropy(redseg_data.values.T, base=np.e).mean()

    # Where does RedSeg stand among the windows?
    rank = sum([x < redseg_conservatism for x in conservatisms]) + 1
    if conservatisms:
        conserved_than_ratio = sum([x >= redseg_conservatism for x in conservatisms]) / len(conservatisms)
    else:
        conserved_than_ratio = float('nan')

    return conservatisms, [redseg_conservatism], rank, conserved_than_ratio

def stat_dist(vxtractor_conservatisms, redseg_conservatisms, ax = None, title = '', redseg_ranks = None, conserved_than_ratio = None):
    """Plot the V-xtractor window-entropy distribution against the RedSeg mean.

    Draws a 30-bin density histogram of ``vxtractor_conservatisms`` with a
    Gaussian-smoothed outline (sigma=1), a dashed vertical line at the mean of
    ``redseg_conservatisms`` and, when ``conserved_than_ratio`` is given, a text
    label reading ``100 - mean(ratio)*100 ± std(ratio)*100 %``.

    Parameters
    ----------
    vxtractor_conservatisms, redseg_conservatisms : sequences of float
        Entropy samples compared with a Mann-Whitney U test.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure is created when None.
    title : str
        Appended to the y-axis label.
    redseg_ranks, conserved_than_ratio : sequences of numbers or None
        Per-fold ranks and ratios. They may be omitted: the text label is then
        skipped and NaN is returned in place of the corresponding mean.

    Returns
    -------
    tuple
        ``(p_value, mean of redseg_ranks, mean of conserved_than_ratio)``.
    """
    # Rank-sum (Mann-Whitney U) test between the two samples
    _, p_value = mannwhitneyu(vxtractor_conservatisms, redseg_conservatisms)

    if ax is None:
        fig, ax = plt.subplots()

    # Histogram computed with np.histogram so the same bins feed bars and the smoothed curve
    hist, bin_edges = np.histogram(vxtractor_conservatisms, bins=30, density=True)
    bin_width = bin_edges[1] - bin_edges[0]
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    redseg_mean = np.mean(redseg_conservatisms)

    ax.bar(bin_centers, hist, width=bin_width, alpha=0.6, edgecolor='black')
    ax.axvline(redseg_mean, color='r', linestyle='--', label='RedSeg', linewidth = 2)
    smoothed_hist = gaussian_filter1d(hist, sigma=1)
    ax.plot(bin_centers, smoothed_hist, color='black', linestyle='-', )

    ax.set_xlabel('Entropy of sliding windows of V-xtractor')
    ax.legend()
    ax.set_ylabel('Frequency of CRs between ' + title)
    y_lim = np.max(hist) * 1.4
    ax.set_ylim([0, y_lim])

    if conserved_than_ratio is not None:
        # Place the label 8 bin widths to the right of the RedSeg line, at 82% of the axis height
        text_x = redseg_mean + bin_width * 8
        text_y = y_lim * 0.82
        ax.text(text_x, text_y, f'{100 - np.mean(conserved_than_ratio) * 100:.2f}±{np.std(conserved_than_ratio) * 100:.2f}%', ha='center')
        mean_ratio = np.mean(conserved_than_ratio)
    else:
        mean_ratio = float('nan')

    mean_rank = np.mean(redseg_ranks) if redseg_ranks is not None else float('nan')

    return p_value, mean_rank, mean_ratio
    
# Create the 2x3 grid of subplots, one panel per selected conserved region
plt.rcParams['font.size'] = 9
fig, axs = plt.subplots(2, 3, figsize = (8.27, 5.5))
stat_res_df = []  # NOTE(review): never filled in this cell — leftover?
fig_i = 0
for conserve_i in [2, 3, 4, 5, 6, 9]:
    # Accumulators pooled over the folds for this conserved region
    vxtractor_conservatisms, redseg_conservatisms = [], []
    redseg_ranks = []
    conserved_than_ratio = []
    ax = axs[fig_i // 3, fig_i % 3]
    fig_i += 1
    # NOTE(review): this starts at fold 1 while the other cells iterate range(0, 10) — confirm fold 0 is skipped on purpose.
    for fold_i in range(1, 10):
        # NOTE(review): absolute, machine-specific path; the rest of the notebook uses crossVal_data_dir — confirm they point to the same data.
        conserve_root = f'/data1/hyzhang/Projects/EcoPrimer_git/DeepEcoPrimer_v2/data/Ten_CrossValidation/Fold_{fold_i}/Conserved'
        vxtractor_file = f'{conserve_root}/Vxtractor_conserve_{conserve_i}_parsedMSA.csv'
        deepseg_file = f'{conserve_root}/16sDeepSeg_conserve_{conserve_i}_parsedMSA.csv'
        # `deletion_cutoff` is a notebook-level variable that must already be defined when this cell runs.
        vx_cons, redseg_cons, rank, conserved_than_t = calculate_conservatism(vxtractor_file, deepseg_file, deletion_cutoff=deletion_cutoff)
        vxtractor_conservatisms += vx_cons
        redseg_conservatisms += redseg_cons
        redseg_ranks.append(rank)
        conserved_than_ratio.append(conserved_than_t)
    
    # Conserved region k lies between V(k-1) and V(k), hence the panel title
    p_value, redseg_mean_rank, redseg_top_ratio = stat_dist(vxtractor_conservatisms, redseg_conservatisms, ax = ax, title = f'V{conserve_i-1}-V{conserve_i}', redseg_ranks = redseg_ranks, conserved_than_ratio = conserved_than_ratio)
    
plt.tight_layout()
fig.savefig(f'{crossVal_data_dir}/Saved_figs/conservatism_comparison.svg', dpi = 300)

In [11]:
# Save the same figure as PDF; relies on `fig` still being alive from the plotting cell.
fig.savefig(f'{crossVal_data_dir}/Saved_figs/conservatism_comparison.pdf', dpi = 300)

4. 重新利用完整数据训练一个RedSeg

In [None]:
# Directory that receives the generated config (and is passed on as the trainer's save_dir)
publicated_model_dir = 'data/Publicated_16sDeepSeg'
# Path of the training-config template
ori_json = '16sDeepSeg/config_train.json'
with open(ori_json, "r") as file:
    # NOTE: `ori_json` is rebound here from the template path (str) to the parsed config (dict).
    ori_json = json.load(file)
    
def create_config_publicatedModel(ori_json, data_dir = '', save_dir = ''):
    """Derive the training config of the published model from a template config.

    Parameters
    ----------
    ori_json : dict
        Template config; must contain ``data_loader.args`` and ``trainer``.
        It is left untouched.
    data_dir : str
        Value stored in ``data_loader.args.data_dir``.
    save_dir : str
        Value stored in ``trainer.save_dir``.

    Returns
    -------
    dict
        An independent copy of the template with the data location, save
        directory, logger config path and validation split overridden.
    """
    import copy

    # dict.copy() is shallow: the nested 'data_loader'/'trainer' dicts would be
    # shared with the template and silently modified by the assignments below.
    json_config = copy.deepcopy(ori_json)
    json_config['data_loader']['args']['data_dir'] = data_dir
    json_config['trainer']['save_dir'] = save_dir
    # Point to the logger config location
    json_config['log_config'] = '16sDeepSeg/logger/logger_config.json'
    # Override the validation data size
    json_config['data_loader']['args']['validation_split'] = 0.02
    return json_config

# Build the config that trains on HVRs_info_merged_complete_clearBac.csv and saves into publicated_model_dir
train_data = os.path.join(crossVal_data_dir, f"HVRs_info_merged_complete_clearBac.csv")
save_path = publicated_model_dir
config_i = create_config_publicatedModel(ori_json, data_dir = train_data, save_dir = save_path)
# Save the resulting JSON config
# NOTE(review): publicated_model_dir is not created in this cell; open() fails if it does not exist yet — confirm it is created elsewhere.
out_json_path = os.path.join(publicated_model_dir, f"config_train_publicatedModel.json")
with open(out_json_path, "w") as file:
    json.dump(config_i, file, indent=4)
# Build the command that launches training (it is only printed here, not executed)
model_train_py = '16sDeepSeg/train.py'
run_cmd = f'python {model_train_py} -c {out_json_path}'

print(f'{run_cmd}') 

python 16sDeepSeg/train.py -c data/Publicated_16sDeepSeg/config_train_publicatedModel.json


In [None]:
# Build a 100-sequence demo file for the 16sDeepSeg usage instructions
demo_input = pd.read_csv('data/Ten_CrossValidation/testingSet_0Fold.csv')
demo_input = demo_input.iloc[:100, :3]  # first 100 rows, first 3 columns
# NOTE(review): the 'input' directory must already exist — confirm.
demo_input.to_csv('input/demo_input_16sDeepSeg.csv', index=False)

demo_input

Unnamed: 0,id,16s_rna,lens
0,CP000302.1607157.1608689,AGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACACA...,1533
1,JQ918082.1.1528,AGAGTTTGATCCTGGCTCAGGGCGAACGCTGGCGGCGTGCCTAACA...,1528
2,CP018447.5249193.5250746,ACTTAAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGG...,1554
3,JN536939.1.1442,GGCAGGCTTACACATGCAAGTCGAGGGGCAGCAGATCATTTCGGTG...,1442
4,EU461376.1.1386,AGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCATGCCTAACA...,1386
...,...,...,...
95,JN038255.1.1495,AGAGTTTGATCATGGCTCAGGACGAACGCTGGCGGCGGGCTTAACA...,1495
96,JQ461008.1.1411,GAGTTTGATTCATGGCTCAAGACGAAACGCTGGCGGCGTGCCTAAT...,1411
97,JQ072447.1.1408,CAAGTCGAACGGCAGCACGGGCTTCGGCCTGGTGGCGAGTGGCGGA...,1408
98,CR628337.408822.410330,GAAGAGTTTGATCCTGGCTCAGATTGAACGCTGGCGGCATGCTTAA...,1509


In [None]:
# Build a file containing all SILVA sequences for 16sDeepSeg to segment
# NOTE(review): absolute, machine-specific paths; ori_data_dir ('data/SILVA_data') is used for the same file name elsewhere in the notebook — confirm they are the same data.
silva_seg_input = pd.read_csv('/data1/hyzhang/Projects/EcoPrimer_git/DeepEcoPrimer_v2/data/SILVA_data/silva_seqs_id_only.csv')
# Keep only the three listed columns, then rename 'silva_id' to 'id'
silva_seg_input = pd.DataFrame(silva_seg_input, columns=['silva_id', '16s_rna', 'lens'])
silva_seg_input.rename(columns={'silva_id': 'id'}, inplace = True)
print(silva_seg_input.shape, silva_seg_input.head())

silva_seg_input.to_csv('/data1/hyzhang/Projects/EcoPrimer_git/DeepEcoPrimer_v2/data/SILVA_data/silva_input_16sDeepSeg.csv', index=False)

(431575, 3)                   id                                            16s_rna  lens
0    AB001445.1.1538  AACTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGC...  1538
1  KM209255.204.1909  AGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACA...  1706
2    HL281554.1.1313  GACGAACGCTGGCGGCGTGCTTAACACATGCAAGTCGAACGAGTGG...  1313
3    AB002515.1.1332  GCCTAATACATGCAAGTTGACGACAGATGATACGTAGCTTGCTACA...  1332
4    AB002523.1.1496  TCCTGGCTCAGGACGAACGCTGGCGGCGTGCCTAATACATGCAAGT...  1496
