In [1]:
import os
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
from dataclasses import dataclass
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

@dataclass
class PerformanceMetrics:
    """Confusion-matrix counts for one score threshold, plus derived rates.

    The ``constant_*`` counters are stored separately from ``tn``/``fn`` but
    are folded into accuracy, recall and specificity below. Every ratio
    evaluates to 0 when its denominator is not positive.
    """
    threshold: float  # score cut-off these counts belong to
    tp: int  # true positives
    fp: int  # false positives
    tn: int  # true negatives
    fn: int  # false negatives
    multi_pred_count: int  # multi-prediction count
    constant_tn: int  # constant true negatives
    constant_fn: int  # constant false negatives

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0 if the denominator is not positive."""
        if denominator > 0:
            return numerator / denominator
        return 0

    @property
    def accuracy(self) -> float:
        """Share of all counted cases (constant ones included) that were right."""
        right = self.tp + self.tn + self.constant_tn
        wrong = self.fp + self.fn + self.constant_fn
        return self._ratio(right, right + wrong)

    @property
    def precision(self) -> float:
        """tp / (tp + fp)."""
        return self._ratio(self.tp, self.tp + self.fp)

    @property
    def recall(self) -> float:
        """tp / (tp + fn + constant_fn)."""
        return self._ratio(self.tp, self.tp + self.fn + self.constant_fn)

    @property
    def specificity(self) -> float:
        """(tn + constant_tn) / (tn + fp + constant_tn)."""
        true_negatives = self.tn + self.constant_tn
        return self._ratio(true_negatives, true_negatives + self.fp)

    @property
    def fpr(self) -> float:
        """False positive rate, defined as ``1 - specificity``."""
        return 1 - self.specificity

    @property
    def f1_score(self) -> float:
        """Harmonic mean of precision and recall."""
        prec, rec = self.precision, self.recall
        return self._ratio(2 * (prec * rec), prec + rec)

class ThresholdOptimizer:
    """阈值优化分析器"""
    
    def __init__(self, data_file: str, output_dir: str):
        """初始化优化器"""
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # 读取数据
        self.df = pd.read_csv(os.path.join(output_dir, data_file))
        
        # 数据验证和统计
        self._validate_and_print_stats()
        
        # 处理final_score
        self._process_final_scores()
        
        # 按transcript_id分组
        self.grouped_predictions = self.df.groupby('transcript_id')
    def _validate_and_print_stats(self):
        """验证数据并打印统计信息"""
        print("\nData Validation and Statistics:")
        print("=" * 50)
        
        # 原始转录本统计
        nm_transcripts = self.df['transcript_id'].str.startswith('NM_')
        nr_transcripts = self.df['transcript_id'].str.startswith('NR_')
        
        print("\n1. 原始数据统计:")
        print(f"总行数: {len(self.df)}")
        print(f"包含NM_的行数: {nm_transcripts.sum()}")
        print(f"包含NR_的行数: {nr_transcripts.sum()}")
        
        # 唯一转录本统计
        unique_nm = self.df[nm_transcripts]['transcript_id'].nunique()
        unique_nr = self.df[nr_transcripts]['transcript_id'].nunique()
        
        print("\n2. 唯一转录本统计:")
        print(f"唯一NM_转录本数: {unique_nm}")
        print(f"唯一NR_转录本数: {unique_nr}")
        
        # 缺失值统计
        print("\n3. 缺失值统计:")
        print(self.df.isnull().sum())
        
        # final_score分布
        print("\n4. final_score分布:")
        print(self.df['final_score'].describe())
        
        # 每个转录本的预测数量
        duplicates = self.df.groupby('transcript_id').size()
        print("\n5. 每个转录本的预测数量分布:")
        print(duplicates.value_counts().sort_index())

    def _process_final_scores(self):
        """处理final_score，包括NaN值处理"""
        # 转换为数值类型
        self.df['final_score'] = pd.to_numeric(self.df['final_score'], errors='coerce')
        
        # 统计NaN值分布
        nan_transcripts = self.df.groupby('transcript_id')['final_score'].apply(
            lambda x: x.isna().all())
        nan_transcripts = nan_transcripts[nan_transcripts]
        
        print("\nProcessing NaN values:")
        print(f"Total transcripts with all NaN scores: {len(nan_transcripts)}")
        print(f"NM transcripts: {sum(i.startswith('NM_') for i in nan_transcripts.index)}")
        print(f"NR transcripts: {sum(i.startswith('NR_') for i in nan_transcripts.index)}")
        
        # 为NaN值分配随机值
        np.random.seed(42)
        for transcript_id in nan_transcripts.index:
            random_score = np.random.uniform(0, 0.1)
            self.df.loc[self.df['transcript_id'] == transcript_id, 'final_score'] = random_score

    def _is_correct_prediction(self, row) -> bool:
        """判断单个预测是否正确"""
        return (pd.notna(row['true_tis']) and 
                pd.notna(row['true_tts']) and 
                row['tis_position'] == row['true_tis'] and 
                row['tts_position'] == row['true_tts'])

    def calculate_rates(self, metrics: PerformanceMetrics) -> Tuple[float, float]:
        """计算TPR和FPR，包含恒定值"""
        tpr = metrics.tp / (metrics.tp + metrics.fn + metrics.constant_fn)
        fpr = metrics.fp / (metrics.fp + metrics.tn + metrics.constant_tn)
        return fpr, tpr

    def evaluate_transcript_level(self, threshold: float) -> PerformanceMetrics:
        """Score every transcript as a single case at the given threshold.

        A prediction row counts as a call when its ``final_score`` is
        strictly greater than ``threshold``. Per transcript:

        * ``NM_`` ids: no call -> FN; exactly one call -> TP if that row is
          correct, else FP; more than one call -> FP.
        * any other id: at least one call -> FP, otherwise TN.

        Transcripts whose scores are all NaN are counted up front as
        ``constant_fn`` (``NM_``) or ``constant_tn`` (others) and skipped.
        ``multi_pred_count`` is the number of transcripts with more than one
        call.
        """
        tp = fp = tn = fn = 0
        constant_tn = constant_fn = 0
        multi_pred_count = 0

        # Pass 1: transcripts without any usable score.
        # NOTE(review): _process_final_scores fills all-NaN transcripts with
        # random scores, so after it has run this pass presumably finds
        # nothing -- confirm that is intended.
        no_score = self.df['final_score'].isna().groupby(self.df['transcript_id']).all()
        already_counted = set()
        for transcript_id, lacks_scores in no_score.items():
            if not lacks_scores:
                continue
            if transcript_id.startswith('NM_'):
                constant_fn += 1
            else:
                constant_tn += 1
            already_counted.add(transcript_id)

        # Pass 2: everything else, one verdict per transcript.
        for transcript_id, group in self.grouped_predictions:
            if transcript_id in already_counted:
                continue

            scores = group['final_score']
            calls = group[scores.notna() & (scores > threshold)]
            n_calls = len(calls)

            if n_calls > 1:
                multi_pred_count += 1

            if not transcript_id.startswith('NM_'):
                # Non-NM transcript: any call at all is a false positive.
                if n_calls > 0:
                    fp += 1
                else:
                    tn += 1
                continue

            if n_calls == 0:
                fn += 1
            elif n_calls == 1 and self._is_correct_prediction(calls.iloc[0]):
                tp += 1
            else:
                # Either the single call is wrong or there are several calls.
                fp += 1

        return PerformanceMetrics(
            threshold=threshold,
            tp=tp,
            fp=fp,
            tn=tn,
            fn=fn,
            multi_pred_count=multi_pred_count,
            constant_tn=constant_tn,
            constant_fn=constant_fn,
        )
        
    def evaluate_orf_level(self, threshold: float) -> PerformanceMetrics:
        """Score every prediction row (ORF) as its own case at the threshold.

        Rows with a NaN ``final_score`` are ignored. For the remaining rows,
        a score strictly above ``threshold`` is a call: TP when the row
        belongs to an ``NM_`` transcript and is correct, FP otherwise. A row
        at or below the threshold is FN for ``NM_`` transcripts and TN for
        all others.

        Transcripts whose scores are all NaN are counted once each as
        ``constant_fn`` (``NM_``) or ``constant_tn`` (others).
        ``multi_pred_count`` is always 0 at this level.
        """
        tp = fp = tn = fn = 0
        constant_tn = constant_fn = 0
        multi_pred_count = 0

        # Pass 1: transcripts without any usable score.
        no_score = self.df['final_score'].isna().groupby(self.df['transcript_id']).all()
        already_counted = set()
        for transcript_id, lacks_scores in no_score.items():
            if not lacks_scores:
                continue
            if transcript_id.startswith('NM_'):
                constant_fn += 1
            else:
                constant_tn += 1
            already_counted.add(transcript_id)

        # Pass 2: one verdict per remaining prediction row.
        for _, row in self.df.iterrows():
            transcript_id = row['transcript_id']
            if transcript_id in already_counted:
                continue

            is_coding = transcript_id.startswith('NM_')
            score = row['final_score']
            if pd.isna(score):
                continue

            if not score > threshold:
                # Not called at this threshold.
                if is_coding:
                    fn += 1
                else:
                    tn += 1
            elif is_coding and self._is_correct_prediction(row):
                tp += 1
            else:
                fp += 1

        return PerformanceMetrics(
            threshold=threshold,
            tp=tp,
            fp=fp,
            tn=tn,
            fn=fn,
            multi_pred_count=multi_pred_count,
            constant_tn=constant_tn,
            constant_fn=constant_fn,
        )

    def _plot_dual_roc_curves(self, transcript_results: List[PerformanceMetrics], 
                           orf_results: List[PerformanceMetrics]):
        """Plot transcript-level and ORF-level ROC curves in one figure.

        Each metrics object is turned into an (FPR, TPR) point through
        ``calculate_rates``. The points are sorted by FPR before plotting and
        integrating: results are produced for increasing thresholds, so FPR
        arrives in mostly decreasing order and integrating it as-is yields a
        negative area.

        Args:
            transcript_results: Metrics per threshold, transcript level.
            orf_results: Metrics per threshold, ORF level.

        Writes ``dual_roc_curves.png`` into ``self.output_dir``.
        """
        # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists.
        integrate = getattr(np, 'trapezoid', None) or np.trapz

        def roc_curve(results):
            """Return (fpr, tpr, auc) with points ordered by FPR, then TPR."""
            points = [self.calculate_rates(metrics) for metrics in results]
            fpr = np.array([p[0] for p in points], dtype=float)
            tpr = np.array([p[1] for p in points], dtype=float)
            order = np.lexsort((tpr, fpr))
            fpr, tpr = fpr[order], tpr[order]
            # A curve needs at least two points to enclose any area.
            auc = float(integrate(tpr, fpr)) if len(fpr) > 1 else 0.0
            return fpr, tpr, auc

        # Compute everything before opening the figure so a failure here does
        # not leave a half-built figure behind.
        fpr_t, tpr_t, auc_t = roc_curve(transcript_results)
        fpr_o, tpr_o, auc_o = roc_curve(orf_results)

        plt.figure(figsize=(4, 3), dpi=600)
        plt.plot(fpr_t, tpr_t, 'b-', label=f'Transcript Level (AUC={auc_t:.3f})')
        plt.plot(fpr_o, tpr_o, 'r-', label=f'ORF Level (AUC={auc_o:.3f})')
        plt.plot([0, 1], [0, 1], 'k--', label='Random')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves: Transcript vs ORF Level')
        plt.legend(loc='lower right')
        plt.grid(True,lw=0.5)
        plt.savefig(os.path.join(self.output_dir, 'dual_roc_curves.png'), dpi=600)
        plt.close()

    def _plot_performance_curves(self, results: List[PerformanceMetrics], level: str):
        """Plot accuracy, precision, recall and F1 against the threshold.

        Args:
            results: Metrics per threshold, in plotting order.
            level: Name used in the title (capitalised) and in the output
                file name ``<level>_performance_curves.png``.
        """
        thresholds = []
        curves = {'Accuracy': [], 'Precision': [], 'Recall': [], 'F1 Score': []}
        for entry in results:
            thresholds.append(entry.threshold)
            curves['Accuracy'].append(entry.accuracy)
            curves['Precision'].append(entry.precision)
            curves['Recall'].append(entry.recall)
            curves['F1 Score'].append(entry.f1_score)

        plt.figure(figsize=(12, 6))
        for curve_name in curves:
            plt.plot(thresholds, curves[curve_name], label=curve_name)

        plt.xlabel('Threshold')
        plt.ylabel('Score')
        plt.title(f'{level.capitalize()} Level Performance Metrics')
        plt.legend()
        plt.grid(True)

        target = os.path.join(self.output_dir, f'{level}_performance_curves.png')
        plt.savefig(target)
        plt.close()

    def _plot_transcript_scores(self, results: List[PerformanceMetrics]):
        """Scatter each transcript's top score, split by NM/NR and outcome.

        The cut-off is the threshold of the entry in ``results`` with the
        highest F1 score. An ``NM_`` transcript is "correct" when its top
        score exceeds the cut-off and any of its rows is a correct
        prediction; any other transcript is "correct" when its top score does
        not exceed the cut-off. Writes ``transcript_scores.png`` into
        ``self.output_dir``.
        """
        best_threshold = max(results, key=lambda entry: entry.f1_score).threshold

        nm_correct, nm_incorrect = [], []
        nr_correct, nr_incorrect = [], []

        for transcript_id, group in self.grouped_predictions:
            top_score = group['final_score'].max()
            above_cutoff = top_score > best_threshold

            if transcript_id.startswith('NM_'):
                # Rows are only inspected when the top score clears the cut-off.
                matched = above_cutoff and any(
                    self._is_correct_prediction(row) for _, row in group.iterrows())
                bucket = nm_correct if matched else nm_incorrect
            else:
                bucket = nr_incorrect if above_cutoff else nr_correct
            bucket.append(top_score)

        plt.figure(figsize=(10, 6))

        # Horizontal jitter (standard deviation) around each category's x position.
        jitter_range = 0.2

        # NM transcripts sit around x = 0.
        nm_x_correct = np.random.normal(0, jitter_range, len(nm_correct))
        nm_x_incorrect = np.random.normal(0, jitter_range, len(nm_incorrect))
        plt.scatter(nm_x_correct, nm_correct,
                    c='blue', alpha=0.5, label='NM Correct', s=1)
        plt.scatter(nm_x_incorrect, nm_incorrect,
                    c='red', alpha=0.5, label='NM Incorrect', s=1)

        # NR transcripts sit around x = 1 (no legend entries of their own).
        nr_x_correct = np.random.normal(1, jitter_range, len(nr_correct))
        nr_x_incorrect = np.random.normal(1, jitter_range, len(nr_incorrect))
        plt.scatter(nr_x_correct, nr_correct,
                    c='blue', alpha=0.5, s=1)
        plt.scatter(nr_x_incorrect, nr_incorrect,
                    c='red', alpha=0.5, s=1)

        # Mark the best-F1 threshold.
        plt.axhline(y=best_threshold, color='g', linestyle='--',
                    label=f'Threshold ({best_threshold:.3f})')

        plt.xticks([0, 1], ['NM', 'NR'])
        plt.ylabel('Final Score')
        plt.title('Transcript Score Distribution')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Count summary placed to the right of the axes.
        stats_text = "\n".join([
            f"NM Correct: {len(nm_correct)}",
            f"NM Incorrect: {len(nm_incorrect)}",
            f"NR Correct: {len(nr_correct)}",
            f"NR Incorrect: {len(nr_incorrect)}",
        ])
        plt.figtext(1.15, 0.5, stats_text, fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.8))

        plt.tight_layout()
        plt.subplots_adjust(right=0.85)
        plt.savefig(os.path.join(self.output_dir, 'transcript_scores.png'),
                    bbox_inches='tight', dpi=300)
        plt.close()

    def _save_performance_data(self, results: List[PerformanceMetrics], output_file: str):
        """保存性能指标数据到文件"""
        with open(output_file, 'w') as f:
            f.write("Threshold\tTP\tFP\tTN\tFN\tConstant_TN\tConstant_FN\tMulti_Pred_Count\t"
                   "Accuracy\tPrecision\tRecall\tSpecificity\tFPR\tF1_Score\n")
            
            for metrics in results:
                f.write(f"{metrics.threshold:.3f}\t")
                f.write(f"{metrics.tp}\t")
                f.write(f"{metrics.fp}\t")
                f.write(f"{metrics.tn}\t")
                f.write(f"{metrics.fn}\t")
                f.write(f"{metrics.constant_tn}\t")
                f.write(f"{metrics.constant_fn}\t")
                f.write(f"{metrics.multi_pred_count}\t")
                f.write(f"{metrics.accuracy:.3f}\t")
                f.write(f"{metrics.precision:.3f}\t")
                f.write(f"{metrics.recall:.3f}\t")
                f.write(f"{metrics.specificity:.3f}\t")
                f.write(f"{metrics.fpr:.3f}\t")
                f.write(f"{metrics.f1_score:.3f}\n")

    def find_optimal_threshold(self, 
                             start: float = 0.0, 
                             end: float = 1.0, 
                             step: float = 0.01) -> Dict:
        """寻找最优阈值，同时计算转录本层面和ORF层面的性能"""
        thresholds = np.arange(start, end + step, step)
        
        transcript_results = []
        orf_results = []
        best_transcript_metrics = None
        best_orf_metrics = None
        best_transcript_f1 = -1
        best_orf_f1 = -1
        
        for threshold in thresholds:
            # 转录本层面评估
            transcript_metrics = self.evaluate_transcript_level(threshold)
            transcript_results.append(transcript_metrics)
            if transcript_metrics.f1_score > best_transcript_f1:
                best_transcript_f1 = transcript_metrics.f1_score
                best_transcript_metrics = transcript_metrics
            
            # ORF层面评估
            orf_metrics = self.evaluate_orf_level(threshold)
            orf_results.append(orf_metrics)
            if orf_metrics.f1_score > best_orf_f1:
                best_orf_f1 = orf_metrics.f1_score
                best_orf_metrics = orf_metrics
        
        # 保存性能指标数据
        self._save_performance_data(transcript_results, 
                                  os.path.join(self.output_dir, "transcript_performance.txt"))
        self._save_performance_data(orf_results, 
                                  os.path.join(self.output_dir, "orf_performance.txt"))
        
        # 绘制性能曲线
        self._plot_performance_curves(transcript_results, "transcript")
        self._plot_performance_curves(orf_results, "orf")
        self._plot_dual_roc_curves(transcript_results, orf_results)
        self._plot_transcript_scores(transcript_results)
        
        return {
            'transcript': {
                'threshold': best_transcript_metrics.threshold,
                'metrics': best_transcript_metrics
            },
            'orf': {
                'threshold': best_orf_metrics.threshold,
                'metrics': best_orf_metrics
            }
        }

def main():
    """Run the threshold sweep on the configured CSV and print both reports."""
    input_file = 'TRANSAID_Embedding_batch4_ALL_tis_tts_pairs.csv'
    file_dir = '/home/jovyan/work/insilico_translation/embedding_type3_maxlen9995_ratio80_NM_NR/TTSTIS_selection'

    # Build the optimizer and sweep thresholds at both evaluation levels.
    optimizer = ThresholdOptimizer(input_file, file_dir)
    results = optimizer.find_optimal_threshold()

    def print_report(heading, level):
        """Print the best threshold, rates and raw counts for one level."""
        best = results[level]
        metrics = best['metrics']
        print(heading)
        print("=" * 50)
        print(f"Best Threshold: {best['threshold']:.3f}")
        print(f"Accuracy: {metrics.accuracy:.3f}")
        print(f"Precision: {metrics.precision:.3f}")
        print(f"Recall: {metrics.recall:.3f}")
        print(f"F1 Score: {metrics.f1_score:.3f}")
        print("\nDetailed Counts:")
        print(f"True Positives: {metrics.tp}")
        print(f"False Positives: {metrics.fp}")
        print(f"True Negatives: {metrics.tn}")
        print(f"False Negatives: {metrics.fn}")
        print(f"Constant True Negatives: {metrics.constant_tn}")
        print(f"Constant False Negatives: {metrics.constant_fn}")
        return metrics

    # Transcript-level report; only this level reports multi-prediction counts.
    transcript_metrics = print_report("\nTranscript Level Results:", 'transcript')
    print(f"\nTranscripts with Multiple Predictions: {transcript_metrics.multi_pred_count}")

    # ORF-level report.
    print_report("\nORF Level Results:", 'orf')

# Run the analysis only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()


Data Validation and Statistics:

1. 原始数据统计:
总行数: 18200
包含NM_的行数: 13493
包含NR_的行数: 4704

2. 唯一转录本统计:
唯一NM_转录本数: 13415
唯一NR_转录本数: 4616

3. 缺失值统计:
transcript_id        0
tis_position      3515
tts_position      3515
tis_sequence      3515
tts_sequence      3515
tis_prob_score    3515
tts_prob_score    3515
kozak_score       3515
cai_score         3515
gc_score          3515
cds_length        3515
final_score       3515
filter_status        0
true_tis          4867
true_tts          4867
dtype: int64

4. final_score分布:
count                  14685
unique                 14645
top       0.1433264726397471
freq                       4
Name: final_score, dtype: object

5. 每个转录本的预测数量分布:
1    17867
2      162
3        3
Name: count, dtype: int64

Processing NaN values:
Total transcripts with all NaN scores: 3516
NM transcripts: 163
NR transcripts: 3352

Transcript Level Results:
Best Threshold: 0.520
Accuracy: 0.917
Precision: 0.923
Recall: 0.967
F1 Score: 0.944

Detailed Counts:
True Positives