In [16]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple, List
from dataclasses import dataclass
import matplotlib.pyplot as plt
from collections import defaultdict
import os

@dataclass
class PerformanceMetrics:
    """Confusion-matrix counts for one score threshold, with derived metrics.

    ``constant_tn`` and ``constant_fn`` are stored separately from ``tn`` and
    ``fn`` but are included in the derived metrics.  A metric whose
    denominator is zero evaluates to ``0.0`` instead of raising
    ``ZeroDivisionError``.
    """
    threshold: float
    tp: int  # true positives
    fp: int  # false positives
    tn: int  # true negatives
    fn: int  # false negatives
    multi_pred_count: int  # number of transcripts with more than one prediction
    constant_fn: int  # false negatives that are constant because of NaN scores
    constant_tn: int  # true negatives that are constant because of NaN scores

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or ``0.0`` for a non-positive denominator."""
        return numerator / denominator if denominator > 0 else 0.0

    @property
    def accuracy(self) -> float:
        """Fraction of all counted transcripts that were classified correctly."""
        total = (
            self.tp + self.fp + self.tn + self.fn
            + self.constant_tn + self.constant_fn
        )
        return self._ratio(self.tp + self.tn + self.constant_tn, total)

    @property
    def precision(self) -> float:
        """``tp / (tp + fp)``."""
        return self._ratio(self.tp, self.tp + self.fp)

    @property
    def recall(self) -> float:
        """``tp / (tp + fn + constant_fn)``."""
        return self._ratio(self.tp, self.tp + self.fn + self.constant_fn)

    @property
    def specificity(self) -> float:
        """True-negative rate: ``(tn + constant_tn) / (tn + constant_tn + fp)``."""
        negatives = self.tn + self.fp + self.constant_tn
        return self._ratio(self.tn + self.constant_tn, negatives)

    @property
    def fpr(self) -> float:
        """False-positive rate, ``1 - specificity``.

        Returns ``0.0`` when there are no negatives at all: with no negative
        cases there cannot be any false positive, so reporting ``1.0`` (what
        ``1 - specificity`` yields from specificity's zero fallback) would be
        misleading.
        """
        negatives = self.tn + self.fp + self.constant_tn
        return 1 - self.specificity if negatives > 0 else 0.0

    @property
    def f1_score(self) -> float:
        """Harmonic mean of precision and recall."""
        prec = self.precision
        rec = self.recall
        return self._ratio(2 * (prec * rec), prec + rec)

class ThresholdOptimizer:
    """阈值优化器类"""
    
    def __init__(self, data_file: str, file_dir: str):
        """
        初始化阈值优化器
        
        Args:
            data_file: 包含预测结果的CSV文件名
            file_dir: 文件目录路径
        """
        self.file_dir = file_dir
        self.data_file = os.path.join(file_dir, data_file)
        self.df = pd.read_csv(self.data_file)
        
        # 数据验证和统计
        self._validate_and_print_stats()
        
        # 转换final_score为数值类型并处理NaN值
        self._process_final_scores()
        
        # 按transcript_id分组，获取每个转录本的所有预测
        self.grouped_predictions = self.df.groupby('transcript_id')
    
    def _process_final_scores(self):
        """转换final_score为数值类型，并为NaN值分配随机值"""
        # 将final_score转换为数值类型
        self.df['final_score'] = pd.to_numeric(self.df['final_score'], errors='coerce')
        
        # 统计NaN值的分布情况
        nan_transcripts = self.df.groupby('transcript_id')['final_score'].apply(lambda x: x.isna().all())
        nan_transcripts = nan_transcripts[nan_transcripts]
        
        print("\nProcessing NaN values:")
        print(f"Total transcripts with all NaN scores: {len(nan_transcripts)}")
        print(f"NM transcripts: {sum(i.startswith('NM_') for i in nan_transcripts.index)}")
        print(f"NR transcripts: {sum(i.startswith('NR_') for i in nan_transcripts.index)}")
        
        # 为每个全是NaN的转录本分配随机值
        np.random.seed(42)  # 设置随机种子确保可重复性
        for transcript_id in nan_transcripts.index:
            # 为该转录本的所有行分配相同的随机值
            random_score = np.random.uniform(0, 0.1)
            self.df.loc[self.df['transcript_id'] == transcript_id, 'final_score'] = random_score

    def _plot_transcript_scores(self, results: List[PerformanceMetrics]):
        """
        Save a jittered scatter plot of the top score of every transcript.

        Transcripts whose id starts with ``NM_`` are drawn around x=0, all
        others around x=1.  Points are blue when the call made at the best-F1
        threshold is correct and red otherwise.  The figure is written to
        ``transcript_scores.png`` inside ``self.file_dir``.

        Args:
            results: Metrics for every evaluated threshold; the threshold of
                the entry with the highest F1 score is the one drawn and used.
        """
        best_threshold = max(results, key=lambda x: x.f1_score).threshold

        def row_matches(row) -> bool:
            # A prediction is exact when both annotated positions exist and
            # both predicted positions equal them.
            return (pd.notna(row['true_tis']) and
                    pd.notna(row['true_tts']) and
                    row['tis_position'] == row['true_tis'] and
                    row['tts_position'] == row['true_tts'])

        # Top scores, bucketed by transcript kind and by outcome.
        transcript_data = {
            'NM': {'correct': [], 'incorrect': []},
            'NR': {'correct': [], 'incorrect': []}
        }

        for transcript_id, group in self.grouped_predictions:
            top_score = group['final_score'].max()
            passes = top_score > best_threshold

            if transcript_id.startswith('NM_'):
                # Correct only if the transcript is called AND one of its
                # predictions is an exact match.
                hit = passes and any(row_matches(row) for _, row in group.iterrows())
                outcome = 'correct' if hit else 'incorrect'
                transcript_data['NM'][outcome].append(top_score)
            else:
                # Everything else is expected to stay at or below the threshold.
                outcome = 'incorrect' if passes else 'correct'
                transcript_data['NR'][outcome].append(top_score)

        plt.figure(figsize=(10, 6))

        # Standard deviation of the horizontal jitter around each category.
        jitter_range = 0.2

        # (x centre, kind, outcome, colour, legend label); only the NM series
        # carry a legend entry.
        series = (
            (0, 'NM', 'correct', 'blue', 'NM Correct'),
            (0, 'NM', 'incorrect', 'red', 'NM Incorrect'),
            (1, 'NR', 'correct', 'blue', None),
            (1, 'NR', 'incorrect', 'red', None),
        )
        for centre, kind, outcome, colour, legend_label in series:
            scores = transcript_data[kind][outcome]
            jittered_x = np.random.normal(centre, jitter_range, len(scores))
            plt.scatter(jittered_x, scores,
                        c=colour, alpha=0.5, label=legend_label, s=1)

        # Horizontal line at the selected threshold.
        plt.axhline(y=best_threshold, color='g', linestyle='--',
                    label=f'Threshold ({best_threshold:.3f})')

        plt.xticks([0, 1], ['NM', 'NR'])
        plt.ylabel('Final Score')
        plt.title('Transcript Score Distribution')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Count summary placed to the right of the axes.
        stats_text = (
            f"NM Correct: {len(transcript_data['NM']['correct'])}\n"
            f"NM Incorrect: {len(transcript_data['NM']['incorrect'])}\n"
            f"NR Correct: {len(transcript_data['NR']['correct'])}\n"
            f"NR Incorrect: {len(transcript_data['NR']['incorrect'])}"
        )
        plt.figtext(1.15, 0.5, stats_text, fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.8))

        # Leave room on the right for the legend and the summary box.
        plt.tight_layout()
        plt.subplots_adjust(right=0.85)

        target = os.path.join(self.file_dir, 'transcript_scores.png')
        plt.savefig(target, bbox_inches='tight', dpi=300)
        plt.close()
    
    def evaluate_threshold(self, threshold: float) -> PerformanceMetrics:
        """
        Classify every transcript at ``threshold`` and count the outcomes.

        Each transcript is judged once, by the highest ``final_score`` among
        its rows:

        * top score > threshold, id starts with ``NM_`` and at least one row
          has ``tis_position == true_tis`` and ``tts_position == true_tts``
          (both annotations present): true positive;
        * top score > threshold otherwise: false positive;
        * top score <= threshold (or NaN), id starts with ``NM_``: false negative;
        * top score <= threshold (or NaN), any other id: true negative.

        The exact-match test looks at all rows of the transcript, not only at
        the top-scoring one.

        Args:
            threshold: Score a transcript's top prediction must exceed to be
                called positive.

        Returns:
            PerformanceMetrics with the four counts and the number of
            transcripts that have more than one row and at least one row with
            ``filter_status == 'passed'``.  ``constant_fn`` and
            ``constant_tn`` are always 0.
        """
        frame = self.df
        keys = frame['transcript_id']

        # Per-transcript aggregates, computed column-wise instead of walking
        # every group and row in Python.  All of them share the same
        # (sorted, unique) transcript_id index.
        top_score = frame['final_score'].groupby(keys).max()
        group_size = keys.groupby(keys).size()
        passed_any = frame['filter_status'].eq('passed').groupby(keys).any()

        exact_row = (
            frame['true_tis'].notna()
            & frame['true_tts'].notna()
            & (frame['tis_position'] == frame['true_tis'])
            & (frame['tts_position'] == frame['true_tts'])
        )
        has_match = exact_row.groupby(keys).any()

        is_nm = pd.Series(
            [transcript_id.startswith('NM_') for transcript_id in top_score.index],
            index=top_score.index,
            dtype=bool,
        )

        # A NaN top score compares False here and so counts as "not called".
        called = top_score > threshold
        hit = called & is_nm & has_match

        tp = int(hit.sum())
        fp = int((called & ~hit).sum())
        fn = int((~called & is_nm).sum())
        tn = int((~called & ~is_nm).sum())
        multi_pred_count = int(((group_size > 1) & passed_any).sum())

        # Sanity check against the grouping prepared at construction time.
        total_counted = tp + fp + tn + fn
        expected_total = len(self.grouped_predictions)
        if total_counted != expected_total:
            print(f"Warning: Count mismatch. Counted: {total_counted}, Expected: {expected_total}")

        return PerformanceMetrics(
            threshold=threshold,
            tp=tp,
            fp=fp,
            tn=tn,
            fn=fn,
            multi_pred_count=multi_pred_count,
            constant_fn=0,
            constant_tn=0
        )
    
    def _save_performance_data(self, results: List[PerformanceMetrics], output_file: str):
        """保存性能指标数据到文件"""
        with open(output_file, 'w') as f:
            # 写入表头
            f.write("Threshold\tTP\tFP\tTN\tFN\tConstant_TN\tConstant_FN\tMulti_Pred_Count\t")
            f.write("Accuracy\tPrecision\tRecall\tSpecificity\tFPR\tF1_Score\n")
            
            # 写入数据
            for metrics in results:
                f.write(f"{metrics.threshold:.3f}\t")
                f.write(f"{metrics.tp}\t")
                f.write(f"{metrics.fp}\t")
                f.write(f"{metrics.tn}\t")
                f.write(f"{metrics.fn}\t")
                f.write(f"{metrics.constant_tn}\t")
                f.write(f"{metrics.constant_fn}\t")
                f.write(f"{metrics.multi_pred_count}\t")
                f.write(f"{metrics.accuracy:.3f}\t")
                f.write(f"{metrics.precision:.3f}\t")
                f.write(f"{metrics.recall:.3f}\t")
                f.write(f"{metrics.specificity:.3f}\t")
                f.write(f"{metrics.fpr:.3f}\t")
                f.write(f"{metrics.f1_score:.3f}\n")
    
    def find_optimal_threshold(self,
                             start: float = 0.0,
                             end: float = 1.0,
                             step: float = 0.01) -> Tuple[float, PerformanceMetrics]:
        """
        Scan a grid of thresholds and return the one with the highest F1 score.

        The grid is ``start, start + step, ...`` and never goes beyond
        ``end``.  When several thresholds share the best F1 score the
        smallest one wins.

        Side effects:
            Writes ``<data file without extension>_performence_search.txt``
            and saves the performance-curve, ROC and transcript-score figures
            through the corresponding ``_plot_*`` methods.

        Args:
            start: First threshold to evaluate.
            end: Upper bound of the grid (inclusive when it lies on the grid).
            step: Distance between consecutive thresholds; must be positive.

        Returns:
            Tuple[float, PerformanceMetrics]: The best threshold and its metrics.

        Raises:
            ValueError: If ``step`` is not positive or ``end`` is below ``start``.
        """
        if step <= 0:
            raise ValueError(f"step must be positive, got {step}")
        if end < start:
            raise ValueError(f"end ({end}) must not be smaller than start ({start})")

        # Derive the number of steps explicitly.  np.arange(start, end + step,
        # step) can return one element too many with a float step (for
        # example step=0.1 yields a final threshold of 1.1).  The small
        # tolerance keeps `end` itself on the grid despite rounding error.
        n_steps = int(np.floor((end - start) / step + 1e-9))
        thresholds = start + step * np.arange(n_steps + 1)

        # Keep every result: the table and all plots need the full scan.
        results = [self.evaluate_threshold(threshold) for threshold in thresholds]

        # max() returns the first maximal element, i.e. the smallest threshold
        # among ties.
        best_metrics = max(results, key=lambda metrics: metrics.f1_score)

        # splitext only strips a real file extension, so a dot in a directory
        # name cannot truncate the path.  The misspelled suffix is kept on
        # purpose so the output path does not change.
        output_base = os.path.splitext(self.data_file)[0]
        output_file = f"{output_base}_performence_search.txt"

        self._save_performance_data(results, output_file)
        self._plot_performance_curves(results)
        self._plot_roc_curve(results)
        self._plot_transcript_scores(results)

        return best_metrics.threshold, best_metrics
    
    def _plot_performance_curves(self, results: List[PerformanceMetrics]):
        """
        Plot accuracy, precision, recall and F1 score against the threshold.

        The figure is saved as ``performance_curves.png`` in ``self.file_dir``.

        Args:
            results: Metrics for every evaluated threshold, in plotting order.
        """
        x_values = [entry.threshold for entry in results]

        # (legend label, PerformanceMetrics attribute), in drawing order.
        curves = (
            ('Accuracy', 'accuracy'),
            ('Precision', 'precision'),
            ('Recall', 'recall'),
            ('F1 Score', 'f1_score'),
        )
        y_values = {
            label: [getattr(entry, attribute) for entry in results]
            for label, attribute in curves
        }

        plt.figure(figsize=(12, 6))
        for label, series in y_values.items():
            plt.plot(x_values, series, label=label)

        plt.xlabel('Threshold')
        plt.ylabel('Score')
        plt.title('Performance Metrics vs Threshold')
        plt.legend()
        plt.grid(True)

        target = os.path.join(self.file_dir, 'performance_curves.png')
        plt.savefig(target)
        plt.close()
    
    def _plot_roc_curve(self, results: List[PerformanceMetrics]):
        """
        Plot the ROC curve (recall against false-positive rate) with its AUC.

        The figure is saved as ``roc_curve.png`` in ``self.file_dir``.  The
        area is integrated only over the FPR range covered by ``results``; no
        (0, 0) or (1, 1) end points are added.

        Args:
            results: Metrics for every evaluated threshold, in any order.
        """
        # Order the points by increasing FPR before integrating.  Results
        # arrive by increasing threshold, which means decreasing FPR, and the
        # trapezoidal rule returns a negative area when x runs backwards.
        points = sorted((r.fpr, r.recall) for r in results)
        fpr = [point[0] for point in points]
        tpr = [point[1] for point in points]  # TPR = recall

        # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
        integrate = getattr(np, 'trapezoid', None) or np.trapz
        auc = integrate(tpr, fpr)

        plt.figure(figsize=(8, 8))
        plt.plot(fpr, tpr, 'b-', label=f'ROC (AUC = {auc:.3f})')
        plt.plot([0, 1], [0, 1], 'r--', label='Random')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc='lower right')
        plt.grid(True)
        plt.savefig(os.path.join(self.file_dir, 'roc_curve.png'))
        plt.close()
    
    def _validate_and_print_stats(self):
        """验证数据并打印统计信息"""
        print("\nData Validation and Statistics:")
        print("=" * 50)
        
        # 原始转录本计数
        nm_transcripts = self.df['transcript_id'].str.startswith('NM_')
        nr_transcripts = self.df['transcript_id'].str.startswith('NR_')
        
        print("\n1. 原始数据统计:")
        print(f"总行数: {len(self.df)}")
        print(f"包含NM_的行数: {nm_transcripts.sum()}")
        print(f"包含NR_的行数: {nr_transcripts.sum()}")
        
        # 唯一转录本计数
        unique_nm = self.df[nm_transcripts]['transcript_id'].nunique()
        unique_nr = self.df[nr_transcripts]['transcript_id'].nunique()
        
        print("\n2. 唯一转录本统计:")
        print(f"唯一NM_转录本数: {unique_nm}")
        print(f"唯一NR_转录本数: {unique_nr}")
        
        # 检查缺失值
        print("\n3. 缺失值统计:")
        print(self.df.isnull().sum())
        
        # 检查final_score列的值分布
        print("\n4. final_score值分布:")
        print(self.df['final_score'].describe())
        
        # 检查是否有重复的transcript_id和position组合
        duplicates = self.df.groupby('transcript_id').size()
        print("\n5. 每个转录本的预测数量分布:")
        print(duplicates.value_counts().sort_index())
    
    def _convert_final_score(self):
        """转换final_score为数值类型，并打印无法转换的值"""
        # 在转换之前查看无法转换为数值的值
        non_numeric = pd.to_numeric(self.df['final_score'], errors='coerce').isna()
        if non_numeric.any():
            print("\n无法转换为数值的final_score值:")
            print(self.df[non_numeric][['transcript_id', 'final_score']])
        
        # 转换为数值类型
        self.df['final_score'] = pd.to_numeric(self.df['final_score'], errors='coerce')

def main(
    input_file: str = 'TRANSAID_Embedding_batch4_ALL_tis_tts_pairs.csv',
    file_dir: str = '/home/jovyan/work/insilico_translation/embedding_type3_maxlen9995_ratio80_NM_NR/TTSTIS_selection',
):
    """
    Run the threshold search on one prediction file and print the best result.

    Args:
        input_file: Name of the CSV file with the predictions.
        file_dir: Directory that contains ``input_file``; the result table and
            the figures are written there as well.

    Both arguments default to the paths that used to be hard-coded, so calling
    ``main()`` without arguments behaves as before.
    """
    optimizer = ThresholdOptimizer(input_file, file_dir)

    # Scan the default threshold grid and keep the best-F1 entry.
    best_threshold, best_metrics = optimizer.find_optimal_threshold()

    print("\nOptimal Threshold Results:")
    print("=" * 50)
    print(f"Best Threshold: {best_threshold:.3f}")
    print(f"Accuracy: {best_metrics.accuracy:.3f}")
    print(f"Precision: {best_metrics.precision:.3f}")
    print(f"Recall: {best_metrics.recall:.3f}")
    print(f"Specificity: {best_metrics.specificity:.3f}")
    print(f"F1 Score: {best_metrics.f1_score:.3f}")
    print("\nDetailed Counts:")
    print(f"True Positives: {best_metrics.tp}")
    print(f"False Positives: {best_metrics.fp}")
    print(f"True Negatives: {best_metrics.tn}")
    print(f"False Negatives: {best_metrics.fn}")
    print(f"\nTranscripts with Multiple Predictions: {best_metrics.multi_pred_count}")

# Run the threshold search only when this module is the entry point,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()


Data Validation and Statistics:

1. 原始数据统计:
总行数: 18200
包含NM_的行数: 13493
包含NR_的行数: 4704

2. 唯一转录本统计:
唯一NM_转录本数: 13415
唯一NR_转录本数: 4616

3. 缺失值统计:
transcript_id        0
tis_position      3515
tts_position      3515
tis_sequence      3515
tts_sequence      3515
tis_prob_score    3515
tts_prob_score    3515
kozak_score       3515
cai_score         3515
gc_score          3515
cds_length        3515
final_score       3515
filter_status        0
true_tis          4867
true_tts          4867
dtype: int64

4. final_score值分布:
count                  14685
unique                 14645
top       0.1433264726397471
freq                       4
Name: final_score, dtype: object

5. 每个转录本的预测数量分布:
1    17867
2      162
3        3
Name: count, dtype: int64

Processing NaN values:
Total transcripts with all NaN scores: 3516
NM transcripts: 163
NR transcripts: 3352

Optimal Threshold Results:
Best Threshold: 0.520
Accuracy: 0.919
Precision: 0.925
Recall: 0.967
Specificity: 0.787
F1 Score: 0.946

Detailed C