In [8]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
from typing import Dict, List, Tuple
from collections import defaultdict
import os

In [2]:
class TranslationSiteAnalyzer:
    """Analyze model predictions around translation start (TIS) / stop (TTS) sites.

    Combines per-transcript prediction records loaded from a pickle file with
    transcript sequences from a FASTA file, and writes figures (PNG) and text
    statistics into ``output_dir``.
    """

    def __init__(self, matching_pkl: str, fasta_file: str, output_dir: str = '.',):
        """Load prediction records and transcript sequences.

        Args:
            matching_pkl: path to the matching_predictions.pkl file. Records are read
                as dicts; this class uses the keys 'predictions_probs',
                'true_labels', 'length' and 'transcript_id'.
            fasta_file: path to the transcript FASTA (e.g. GRCh38_latest_rna.fna).
            output_dir: directory receiving all figures and statistics files;
                created if it does not exist.
        """
        self.output_dir = output_dir
        # Every analysis writes into output_dir, so make sure it exists up front.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(matching_pkl, 'rb') as f:
            self.matching_data = pickle.load(f)

        # Key sequences by accession without version suffix ("NM_000014.6" -> "NM_000014").
        self.sequences = {}
        for record in SeqIO.parse(fasta_file, "fasta"):
            self.sequences[record.id.split('.')[0]] = str(record.seq)

        sns.set_theme(style="whitegrid")

    # ------------------------------------------------------------------
    # TIS / TTS probability analysis
    # ------------------------------------------------------------------
    def analyze_tis_tts_probabilities(self, window_size: int = 6):
        """Analyze predicted probabilities at and around annotated TIS/TTS sites.

        A site is any index ``i`` where ``true_labels[i:i+3]`` all carry the site
        label (0 = TIS, 1 = TTS). For every site, the probability of the site's own
        class (column 0 of ``predictions_probs`` for TIS, column 1 for TTS) is
        collected at offsets ``-window_size .. window_size + 3`` relative to ``i``.

        Args:
            window_size: number of positions upstream of the site; downstream
                positions cover offsets 3 .. window_size + 3. Default 6.
        """
        (tis_positions, tis_site_probs,
         tis_upstream_probs, tis_downstream_probs) = self._collect_site_windows(0, 0, window_size)
        (tts_positions, tts_site_probs,
         tts_upstream_probs, tts_downstream_probs) = self._collect_site_windows(1, 1, window_size)

        print(f"Found {len(tis_site_probs)} TIS sites")
        print(f"Found {len(tts_site_probs)} TTS sites")

        # Position-specific distributions
        self._plot_site_probabilities(tis_positions, "TIS")
        self._plot_site_probabilities(tts_positions, "TTS")

        # Mean/variance over the 3 site positions
        self._plot_and_save_site_stats(tis_site_probs, "TIS")
        self._plot_and_save_site_stats(tts_site_probs, "TTS")

        # Mean/variance of the upstream / downstream context windows
        self._plot_and_save_context_stats(tis_upstream_probs, tis_downstream_probs, "TIS")
        self._plot_and_save_context_stats(tts_upstream_probs, tts_downstream_probs, "TTS")

        # Per-position summary statistics
        self._print_probability_statistics(tis_positions, "TIS")
        self._print_probability_statistics(tts_positions, "TTS")

    def _collect_site_windows(self, site_label: int, class_idx: int, window_size: int
                              ) -> Tuple[Dict[str, List[float]], List[List[float]],
                                         List[List[float]], List[List[float]]]:
        """Collect class probabilities around every site whose 3 labels equal ``site_label``.

        Args:
            site_label: label value marking the site (0 = TIS, 1 = TTS).
            class_idx: column of ``predictions_probs`` holding that class's probability.
            window_size: number of upstream positions (see analyze_tis_tts_probabilities).

        Returns:
            (positions, site_probs, upstream_probs, downstream_probs):
            ``positions`` maps 'pos{offset}' to all probabilities seen at that offset;
            ``site_probs`` holds the 3 on-site probabilities per site; the
            up/downstream lists hold the flanking windows, one list per site.
        """
        offsets = list(range(-window_size, window_size + 4))
        positions = {f'pos{j}': [] for j in offsets}
        site_probs, upstream_probs, downstream_probs = [], [], []

        for item in self.matching_data:
            probs = item['predictions_probs']
            labels = item['true_labels']
            # Never index past the shortest of the declared length and the arrays.
            seq_len = min(item['length'], len(labels), len(probs))

            # The largest index read is i + window_size + 3, which must be < seq_len.
            for i in range(window_size, seq_len - window_size - 3):
                if not (labels[i] == site_label and labels[i + 1] == site_label
                        and labels[i + 2] == site_label):
                    continue
                # Only the probability of the site's own class is meaningful here;
                # mixing all classes (which sum to 1) would make the means uninformative.
                window = [probs[i + j][class_idx] for j in offsets]
                upstream_probs.append(window[:window_size])
                site_probs.append(window[window_size:window_size + 3])
                downstream_probs.append(window[window_size + 3:])
                for j, value in zip(offsets, window):
                    positions[f'pos{j}'].append(value)

        return positions, site_probs, upstream_probs, downstream_probs

    def _save_histogram(self, values, title: str, xlabel: str, filename: str):
        """Save a 50-bin histogram of ``values`` to ``output_dir/filename``."""
        plt.figure(figsize=(8, 6))
        plt.hist(values, bins=50, alpha=0.7)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel('Frequency')
        plt.savefig(os.path.join(self.output_dir, filename))
        plt.close()

    def _plot_and_save_site_stats(self, site_probs: List[List], site_type: str):
        """Plot and save per-site mean/variance of the probabilities on the site positions.

        Args:
            site_probs: one entry per site, each holding the probabilities at the
                3 site positions.
            site_type: site type label ('TIS' or 'TTS'), used in titles and filenames.
        """
        site_array = np.asarray(site_probs, dtype=float)
        if site_array.size == 0:
            print(f"No {site_type} sites; skipping site statistics")
            return

        site_means = site_array.mean(axis=1)
        site_vars = site_array.var(axis=1)

        self._save_histogram(site_means, f'{site_type} Site Probability Mean Distribution',
                             'Mean Probability', f'{site_type}_site_mean_dist.png')
        self._save_histogram(site_vars, f'{site_type} Site Probability Variance Distribution',
                             'Variance', f'{site_type}_site_var_dist.png')

        stats_file = os.path.join(self.output_dir, f'{site_type}_site_stats.txt')
        with open(stats_file, 'w') as f:
            f.write(f"{site_type} Site Statistics:\n")
            f.write("="*30 + "\n\n")
            f.write("Mean Statistics:\n")
            f.write(f"Overall Mean: {np.mean(site_means):.4f}\n")
            f.write(f"Mean Std: {np.std(site_means):.4f}\n")
            f.write(f"Mean Q1: {np.percentile(site_means, 25):.4f}\n")
            f.write(f"Mean Q3: {np.percentile(site_means, 75):.4f}\n\n")
            f.write("Variance Statistics:\n")
            f.write(f"Overall Variance: {np.mean(site_vars):.4f}\n")
            f.write(f"Variance Std: {np.std(site_vars):.4f}\n")
            f.write(f"Variance Q1: {np.percentile(site_vars, 25):.4f}\n")
            f.write(f"Variance Q3: {np.percentile(site_vars, 75):.4f}\n")

    def _plot_and_save_context_stats(self, upstream_probs: List[List],
                                    downstream_probs: List[List],
                                    site_type: str):
        """Plot and save mean/variance of the upstream and downstream context windows.

        Args:
            upstream_probs: one upstream window of probabilities per site.
            downstream_probs: one downstream window of probabilities per site.
            site_type: site type label ('TIS' or 'TTS').
        """
        up_array = np.asarray(upstream_probs, dtype=float)
        down_array = np.asarray(downstream_probs, dtype=float)
        # Empty input (no sites) or a zero-width window has no defined statistics.
        if up_array.size == 0 or down_array.size == 0:
            print(f"No {site_type} context data; skipping context statistics")
            return

        up_means = up_array.mean(axis=1)
        up_vars = up_array.var(axis=1)
        down_means = down_array.mean(axis=1)
        down_vars = down_array.var(axis=1)

        # Left panel: window means; right panel: window variances.
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.hist(up_means, bins=50, alpha=0.7, label='Upstream')
        plt.hist(down_means, bins=50, alpha=0.7, label='Downstream')
        plt.title(f'{site_type} Context Mean Distribution')
        plt.xlabel('Mean Probability')
        plt.ylabel('Frequency')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.hist(up_vars, bins=50, alpha=0.7, label='Upstream')
        plt.hist(down_vars, bins=50, alpha=0.7, label='Downstream')
        plt.title(f'{site_type} Context Variance Distribution')
        plt.xlabel('Variance')
        plt.ylabel('Frequency')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, f'{site_type}_context_dist.png'))
        plt.close()

        stats_file = os.path.join(self.output_dir, f'{site_type}_context_stats.txt')
        with open(stats_file, 'w') as f:
            f.write(f"{site_type} Context Statistics:\n")
            f.write("="*30 + "\n\n")

            f.write("Upstream Statistics:\n")
            f.write("-"*20 + "\n")
            f.write(f"Mean: {np.mean(up_means):.4f}\n")
            f.write(f"Std: {np.std(up_means):.4f}\n")
            f.write(f"Variance Mean: {np.mean(up_vars):.4f}\n")
            f.write(f"Variance Std: {np.std(up_vars):.4f}\n\n")

            f.write("Downstream Statistics:\n")
            f.write("-"*20 + "\n")
            f.write(f"Mean: {np.mean(down_means):.4f}\n")
            f.write(f"Std: {np.std(down_means):.4f}\n")
            f.write(f"Variance Mean: {np.mean(down_vars):.4f}\n")
            f.write(f"Variance Std: {np.std(down_vars):.4f}\n")

    def _plot_site_probabilities(self, position_data, site_type):
        """Draw a violin + jittered-scatter plot of probabilities per relative position.

        Args:
            position_data: mapping 'pos{offset}' -> list of probabilities.
            site_type: site type label ('TIS' or 'TTS').
        """
        positions = sorted(position_data.keys(), key=lambda x: int(x[3:]))
        data = [position_data[pos] for pos in positions]
        # violinplot cannot handle empty datasets (e.g. no sites were found).
        if not any(len(d) for d in data):
            print(f"No {site_type} position data; skipping position plot")
            return

        plt.figure(figsize=(6, 4))
        vp = plt.violinplot(data, positions=list(range(len(positions))),
                            showmeans=True, showextrema=True)
        for pc in vp['bodies']:
            pc.set_facecolor('lightblue')
            pc.set_alpha(0.7)

        # Seeded jitter so re-running the notebook reproduces the same figure.
        rng = np.random.default_rng(0)
        for i, d in enumerate(data):
            scatter_x = rng.normal(i, 0.05, size=len(d))
            plt.scatter(scatter_x, d, alpha=0.2, c='navy', s=5)

        plt.title(f'{site_type} Position-specific Probability Distribution')
        plt.xlabel('Position relative to site')
        plt.ylabel('Probability')
        plt.xticks(range(len(positions)), [p[3:] for p in positions], rotation=45)
        plt.grid(True, axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, f'{site_type}_position_distribution.png'), dpi=300)
        plt.close()

    def _print_probability_statistics(self, position_data, site_type):
        """Write per-position count/mean/std/median/quartiles to ``{site_type}_statistics.txt``.

        Args:
            position_data: mapping 'pos{offset}' -> list of probabilities.
            site_type: site type label ('TIS' or 'TTS').
        """
        with open(os.path.join(self.output_dir, f'{site_type}_statistics.txt'), 'w') as f:
            f.write(f"{site_type} Probability Statistics:\n")
            f.write("="*30 + "\n\n")

            for pos in sorted(position_data.keys(), key=lambda x: int(x[3:])):
                values = position_data[pos]
                if values:
                    f.write(f"\nPosition {pos}:\n")
                    f.write(f"Count: {len(values)}\n")
                    f.write(f"Mean: {np.mean(values):.4f}\n")
                    f.write(f"Std: {np.std(values):.4f}\n")
                    f.write(f"Median: {np.median(values):.4f}\n")
                    f.write(f"Q1: {np.percentile(values, 25):.4f}\n")
                    f.write(f"Q3: {np.percentile(values, 75):.4f}\n")
                    f.write("-"*20 + "\n")

    # ------------------------------------------------------------------
    # Group config files and sequence lookup
    # ------------------------------------------------------------------
    @staticmethod
    def _read_group_file(path: str) -> Tuple[List[str], List[Tuple[str, ...]], List[Tuple[str, ...]]]:
        """Parse a two-column ``group_A;group_B`` config file.

        The first line is the header (used as plot labels). Each further line is
        ``<cellA>;<cellB>`` where a cell is ``<transcript_id>,<value>`` or empty.
        Blank lines are skipped.

        Returns:
            (header, group_a, group_b) with each group a list of split cells.
        """
        group_a, group_b = [], []
        with open(path, 'r') as f:
            header = f.readline().strip().split(';')
            for line in f:
                line = line.strip()
                if not line:
                    continue
                a, b = line.split(';')
                if a:
                    group_a.append(tuple(a.split(',')))
                if b:
                    group_b.append(tuple(b.split(',')))
        return header, group_a, group_b

    def _lookup_sequence(self, transcript_id: str):
        """Return the sequence for ``transcript_id`` (version suffix ignored), or None."""
        # self.sequences is keyed without version, so normalize the query the same way.
        return self.sequences.get(str(transcript_id).strip().split('.')[0])

    # ------------------------------------------------------------------
    # Kozak analysis
    # ------------------------------------------------------------------
    def analyze_kozak_sequences(self, file_kozak: str):
        """Compare Kozak PWM scores of the two groups in ``file_kozak``.

        Args:
            file_kozak: ';'-separated two-column file; each cell is
                ``<transcript_id>,<start codon position>``.
        """
        header, group_a, group_b = self._read_group_file(file_kozak)

        scores_a = self._calculate_kozak_scores(group_a)
        scores_b = self._calculate_kozak_scores(group_b)

        self._plot_distribution_comparison(
            scores_a, scores_b, header,
            'Kozak Sequence Score Distribution',
            'Kozak Score',
            'kozak_scores.png'
        )

    def _calculate_kozak_scores(self, group_data: List[Tuple[str, str]]) -> List[float]:
        """Score the Kozak context of each (transcript_id, position) entry.

        ``position`` is the index of the A of the start codon. The scored window is
        the PWM's upstream bases, the start codon, and the PWM's downstream bases.
        Entries whose transcript is unknown or whose window leaves the sequence are
        skipped.
        """
        pwm = self._build_pwm()
        upstream = max(0, -min(pwm)) if pwm else 0
        downstream = max((k for k in pwm if k > 0), default=0)
        scores = []

        for transcript_id, pos in group_data:
            seq = self._lookup_sequence(transcript_id)
            if seq is None:
                continue
            pos = int(pos)
            end = pos + 3 + downstream
            if pos >= upstream and end <= len(seq):
                scores.append(self._score_sequence(seq[pos - upstream:end], pwm))

        return scores

    # ------------------------------------------------------------------
    # CAI / CDS length analysis
    # ------------------------------------------------------------------
    def analyze_cai_and_length(self, file_cai: str):
        """Compare CAI values and CDS lengths of the two groups in ``file_cai``.

        Args:
            file_cai: ';'-separated two-column file; each cell is
                ``<transcript_id>,<start>-<end>``.
        """
        header, group_a, group_b = self._read_group_file(file_cai)

        cai_a, len_a = self._calculate_cai_and_length(group_a)
        cai_b, len_b = self._calculate_cai_and_length(group_b)

        self._plot_distribution_comparison(
            cai_a, cai_b, header,
            'CAI Value Distribution',
            'CAI Value',
            'cai_distribution.png'
        )
        self._plot_distribution_comparison(
            len_a, len_b, header,
            'CDS Length Distribution',
            'Length (bp)',
            'cds_length_distribution.png'
        )

    def _calculate_cai_and_length(self, group_data: List[Tuple[str, str]]):
        """Compute CAI and CDS length for each (transcript_id, "start-end") entry.

        ``start`` and ``end`` are both inclusive: start is the first base of the
        start codon and end the last base of the stop codon. This matches the CAI/CDS
        config export, which writes the last index of the stop triple. Unknown
        transcripts and inverted ranges are skipped.

        Returns:
            (cai_values, lengths) as parallel lists.
        """
        codon_usage = self._get_codon_usage()
        cai_values = []
        lengths = []

        for transcript_id, pos_range in group_data:
            seq = self._lookup_sequence(transcript_id)
            if seq is None:
                continue
            start, end = map(int, pos_range.split('-'))
            if end < start:
                continue
            cds = seq[start:end + 1]
            cai_values.append(self._calculate_cai(cds, codon_usage))
            lengths.append(end - start + 1)

        return cai_values, lengths

    def _plot_distribution_comparison(self, data_a: List, data_b: List,
                                    labels: List[str], title: str,
                                    ylabel: str, filename: str):
        """Draw a two-group violin + scatter comparison and save it.

        ``filename`` is resolved relative to ``self.output_dir``; an absolute path is
        used as-is. Nothing is drawn if either group is empty.
        """
        if len(data_a) == 0 or len(data_b) == 0:
            print(f"Skipping '{title}': group sizes are {len(data_a)} and {len(data_b)}")
            return

        plt.figure(figsize=(10, 6))
        plt.violinplot([data_a, data_b], positions=[1, 2])
        plt.scatter(np.ones(len(data_a)), data_a, alpha=0.2, color='blue')
        plt.scatter(np.full(len(data_b), 2), data_b, alpha=0.2, color='red')

        plt.xticks([1, 2], labels)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(os.path.join(self.output_dir, filename))
        plt.close()

    def _build_pwm(self) -> Dict:
        """Return the Kozak position weight matrix.

        Negative keys are positions upstream of ATG (-1 = base just before the A).
        Positive keys are positions after the start codon (1 = first base after ATG).
        The start codon itself is not part of the matrix.
        """
        return {
            -6: {'A': 0.22, 'C': 0.28, 'G': 0.32, 'T': 0.18},
            -5: {'A': 0.20, 'C': 0.30, 'G': 0.30, 'T': 0.20},
            -4: {'A': 0.18, 'C': 0.32, 'G': 0.30, 'T': 0.20},
            -3: {'A': 0.25, 'C': 0.15, 'G': 0.45, 'T': 0.15},  # important position
            -2: {'A': 0.20, 'C': 0.35, 'G': 0.25, 'T': 0.20},
            -1: {'A': 0.20, 'C': 0.35, 'G': 0.25, 'T': 0.20},
            1:  {'A': 0.20, 'C': 0.20, 'G': 0.40, 'T': 0.20},  # first base after ATG (+4)
            2:  {'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25},
            3:  {'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25}
        }

    def _score_sequence(self, seq: str, pwm: Dict) -> float:
        """Return the product of PWM weights over a Kozak window.

        ``seq`` is laid out as upstream bases + start codon + downstream bases. The
        number of upstream bases is ``-min(pwm)`` (6 for the default PWM). Upstream
        bases map to keys -6..-1, and the 3 start-codon bases are not scored.
        Downstream bases map to keys 1, 2, 3, .... Bases without a weight (e.g. 'N')
        are skipped.
        """
        upstream = max(0, -min(pwm)) if pwm else 0
        score = 1.0
        for i, base in enumerate(seq.upper()):
            if i < upstream:
                key = i - upstream
            elif i < upstream + 3:
                continue  # the start codon itself
            else:
                key = i - upstream - 2  # first base after the start codon -> key 1
            weights = pwm.get(key)
            if weights is not None and base in weights:
                score *= weights[base]
        return score

    def _get_codon_usage(self) -> Dict[str, float]:
        """Return human codon usage as fractions within each synonymous codon family.

        Each amino acid's codons sum to ~1 in this table (e.g. TTT 0.45 + TTC 0.55).
        """
        return {
            'TTT': 0.45, 'TTC': 0.55, 'TTA': 0.07, 'TTG': 0.13,
            'TCT': 0.18, 'TCC': 0.22, 'TCA': 0.15, 'TCG': 0.06,
            'TAT': 0.43, 'TAC': 0.57, 'TAA': 0.28, 'TAG': 0.20,
            'TGT': 0.45, 'TGC': 0.55, 'TGA': 0.52, 'TGG': 1.00,
            'CTT': 0.13, 'CTC': 0.20, 'CTA': 0.07, 'CTG': 0.41,
            'CCT': 0.28, 'CCC': 0.33, 'CCA': 0.27, 'CCG': 0.11,
            'CAT': 0.41, 'CAC': 0.59, 'CAA': 0.25, 'CAG': 0.75,
            'CGT': 0.08, 'CGC': 0.19, 'CGA': 0.11, 'CGG': 0.21,
            'ATT': 0.36, 'ATC': 0.48, 'ATA': 0.16, 'ATG': 1.00,
            'ACT': 0.24, 'ACC': 0.36, 'ACA': 0.28, 'ACG': 0.12,
            'AAT': 0.46, 'AAC': 0.54, 'AAA': 0.42, 'AAG': 0.58,
            'AGT': 0.15, 'AGC': 0.24, 'AGA': 0.20, 'AGG': 0.20,
            'GTT': 0.18, 'GTC': 0.24, 'GTA': 0.11, 'GTG': 0.47,
            'GCT': 0.26, 'GCC': 0.40, 'GCA': 0.23, 'GCG': 0.11,
            'GAT': 0.46, 'GAC': 0.54, 'GAA': 0.42, 'GAG': 0.58,
            'GGT': 0.16, 'GGC': 0.34, 'GGA': 0.25, 'GGG': 0.25
        }

    @staticmethod
    def _standard_genetic_code() -> Dict[str, str]:
        """Map codon -> one-letter amino acid for the standard code (NCBI table 1; '*' = stop)."""
        bases = "TCAG"
        amino_acids = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
        codons = (b1 + b2 + b3 for b1 in bases for b2 in bases for b3 in bases)
        return dict(zip(codons, amino_acids))

    def _relative_adaptiveness(self, codon_usage: Dict[str, float]) -> Dict[str, float]:
        """Return CAI weights w = usage / max usage among synonymous codons.

        Stop codons and amino acids encoded by a single codon (Met, Trp) are left
        out because they carry no codon-choice information. Applying this to values
        that are already relative adaptiveness (family maximum 1) leaves them
        unchanged.
        """
        families = defaultdict(list)
        for codon, aa in self._standard_genetic_code().items():
            families[aa].append(codon)

        weights = {}
        for aa, codons in families.items():
            if aa == '*' or len(codons) < 2:
                continue
            best = max(codon_usage.get(c, 0.0) for c in codons)
            if best <= 0:
                continue
            for c in codons:
                usage = codon_usage.get(c, 0.0)
                if usage > 0:
                    weights[c] = usage / best
        return weights

    def _calculate_cai(self, sequence: str, codon_usage: Dict[str, float]) -> float:
        """Return the Codon Adaptation Index of ``sequence``.

        CAI is the geometric mean of the relative adaptiveness ``w`` of each
        in-frame codon. Codons without a weight (stop, Met, Trp, ambiguous bases)
        are ignored. Returns 0.0 when no informative codon is present.
        """
        if len(sequence) < 3:
            return 0.0

        weights = self._relative_adaptiveness(codon_usage)
        seq = sequence.upper()
        codons = (seq[i:i + 3] for i in range(0, len(seq) - 2, 3))
        values = [weights[c] for c in codons if c in weights]

        if not values:
            return 0.0
        return float(np.exp(np.mean(np.log(values))))

    # ------------------------------------------------------------------
    # Report
    # ------------------------------------------------------------------
    def analyze_sequence_features(self):
        """Write sequence_analysis_report.txt with transcript-type counts and length stats."""
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        total = len(self.matching_data)
        report_path = os.path.join(self.output_dir, 'sequence_analysis_report.txt')
        with open(report_path, 'w') as f:
            f.write("Sequence Analysis Report\n")
            f.write("=======================\n\n")

            f.write("1. Basic Statistics\n")
            f.write("-----------------\n")
            f.write(f"Total sequences analyzed: {total}\n")
            if total == 0:
                # Percentages and length statistics are undefined without records.
                return

            # Transcript type = accession prefix before the first '_' (e.g. "NM", "NR").
            transcript_types = defaultdict(int)
            for item in self.matching_data:
                transcript_types[item['transcript_id'].split('_')[0]] += 1

            f.write("\nTranscript Type Distribution:\n")
            for t_type, count in transcript_types.items():
                f.write(f"{t_type}: {count} ({count/total*100:.2f}%)\n")

            f.write("\n2. Sequence Length Distribution\n")
            f.write("----------------------------\n")
            lengths = [item['length'] for item in self.matching_data]
            f.write(f"Mean length: {np.mean(lengths):.2f}\n")
            f.write(f"Median length: {np.median(lengths):.2f}\n")
            f.write(f"Std length: {np.std(lengths):.2f}\n")
            f.write(f"Min length: {min(lengths)}\n")
            f.write(f"Max length: {max(lengths)}\n")

In [7]:
def main(base_dir: str = "/home/jovyan/work/insilico_translation/embedding_type3_maxlen9995_ratio80_NM_NR",
         fasta_file: str = "/home/jovyan/work/insilico_translation/dataset/GRCh38_latest_rna.fna"):
    """Run every TranslationSiteAnalyzer analysis for one model-output directory.

    The Kozak and CAI/CDS config files must already exist in the output folder;
    they are written by the export cells further down this notebook.

    Args:
        base_dir: directory holding the model's prediction pickles. Figures,
            statistics and config files live in its 'Statictis' subfolder.
        fasta_file: transcript FASTA used for sequence lookup.
    """
    # 'Statictis' (sic) is the existing folder name on disk; keep it unchanged.
    output_dir = os.path.join(base_dir, "Statictis")
    os.makedirs(output_dir, exist_ok=True)

    analyzer = TranslationSiteAnalyzer(
        matching_pkl=os.path.join(base_dir, "TRANSAID_Embedding_batch4_NM_matching_predictions.pkl"),
        fasta_file=fasta_file,
        output_dir=output_dir
    )

    # TIS/TTS probability distributions
    analyzer.analyze_tis_tts_probabilities()

    # Kozak context scores
    analyzer.analyze_kozak_sequences(os.path.join(output_dir, "kozak_config_file.txt"))

    # CAI and CDS length
    analyzer.analyze_cai_and_length(os.path.join(output_dir, "CAI_CDS_config_file.txt"))

    # Summary report
    analyzer.analyze_sequence_features()

if __name__ == "__main__":
    main()

Found 12009 TIS sites
Found 12029 TTS sites


In [4]:
import pickle
import numpy as np
from typing import List, Tuple, Dict
import os
import warnings

def load_pickle_data(file_path: str) -> List[Dict]:
    """Deserialize and return the object stored in a pickle file.

    NOTE: unpickling can execute arbitrary code; only use this on trusted files.

    Args:
        file_path: path to the pickle file.

    Returns:
        The unpickled object (a list of prediction records in this notebook).
    """
    with open(file_path, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload

def find_first_tis_position(array: np.ndarray) -> int:
    """Return the start index of the first run of three consecutive 0 values.

    Args:
        array: prediction or true-label sequence (indexable, supports slicing).

    Returns:
        Index of the first element of the earliest 0,0,0 triple, or -1 if none.
    """
    triples = zip(array, array[1:], array[2:])
    for start, triple in enumerate(triples):
        if all(value == 0 for value in triple):
            return start
    return -1

def collect_group_a_data(nm_data: List[Dict]) -> List[Tuple[str, int]]:
    """Collect group A TIS positions from the ground-truth labels of NM transcripts.

    A TIS is any index ``i`` where ``true_labels[i:i+3]`` are all 0. Only the first
    TIS of each transcript is kept. Transcripts with several TIS emit a warning
    listing all of them, and transcripts without any are omitted.

    Args:
        nm_data: NM prediction records, each with 'true_labels' and 'transcript_id'.

    Returns:
        List of (transcript_id, first TIS index) tuples in input order.
    """
    collected = []
    for record in nm_data:
        labels = np.array(record['true_labels'])
        tid = record['transcript_id']

        # Vectorized triple check: positions where labels[i], labels[i+1], labels[i+2] are 0.
        is_zero = labels == 0
        hits = np.flatnonzero(is_zero[:-2] & is_zero[1:-1] & is_zero[2:]).tolist()

        if not hits:
            continue
        if len(hits) > 1:
            warnings.warn(f"Multiple TIS found in transcript {tid}: {hits}")
        collected.append((tid, hits[0]))

    return collected

def collect_group_b_data(nr_data: List[Dict]) -> List[Tuple[str, int]]:
    """Collect group B TIS positions from the model predictions of NR transcripts.

    The TIS is the first index where three consecutive predictions equal 0;
    transcripts without such a run are omitted.

    Args:
        nr_data: NR prediction records, each with 'predictions' and 'transcript_id'.

    Returns:
        List of (transcript_id, TIS index) tuples in input order.
    """
    found = []
    for record in nr_data:
        predicted = np.array(record['predictions'])
        tid = record['transcript_id']

        tis = find_first_tis_position(predicted)
        if tis == -1:
            continue
        found.append((tid, tis))

    return found

def save_to_file(group_a: List[Tuple[str, int]], 
                 group_b: List[Tuple[str, int]], 
                 output_file: str):
    """Write both groups side by side as a ';'-separated two-column file.

    The first line is the header ``group_A;group_B``. Row k holds ``id,pos`` for
    the k-th entry of each group; the shorter group is padded with empty cells.

    Args:
        group_a: (transcript_id, position) entries for column A.
        group_b: (transcript_id, position) entries for column B.
        output_file: destination path (overwritten).
    """
    def cell(group, row):
        if row >= len(group):
            return ""
        entry = group[row]
        return f"{entry[0]},{entry[1]}"

    n_rows = max(len(group_a), len(group_b))

    with open(output_file, 'w') as out:
        out.write('group_A;group_B\n')
        for row in range(n_rows):
            out.write(f"{cell(group_a, row)};{cell(group_b, row)}\n")

def main(base_dir: str = "/home/jovyan/work/insilico_translation/embedding_type3_maxlen9995_ratio80_NM_NR"):
    """Export first-TIS positions of NM (group A) and NR (group B) transcripts.

    Group A uses the ground-truth labels of the NM pickle, and group B uses the
    predictions of the NR pickle. The result is written to
    ``<base_dir>/Statictis/kozak_config_file.txt``, the file read by the Kozak
    analysis earlier in this notebook.

    Args:
        base_dir: model-output directory holding the NM/NR prediction pickles.
    """
    nm_file = os.path.join(base_dir, "TRANSAID_Embedding_batch4_NM_matching_predictions.pkl")
    nr_file = os.path.join(base_dir, "TRANSAID_Embedding_batch4_NR_non_matching_predictions.pkl")
    # 'Statictis' (sic) is the existing output folder name on disk; keep it unchanged.
    output_file = os.path.join(base_dir, "Statictis", "kozak_config_file.txt")

    print("Loading NM data...")
    nm_data = load_pickle_data(nm_file)
    print("Loading NR data...")
    nr_data = load_pickle_data(nr_file)

    print("Collecting group A data...")
    group_a = collect_group_a_data(nm_data)
    print(f"Found {len(group_a)} TIS sites in group A")

    print("Collecting group B data...")
    group_b = collect_group_b_data(nr_data)
    print(f"Found {len(group_b)} TIS sites in group B")

    print("Saving results...")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_to_file(group_a, group_b, output_file)
    print(f"Results saved to {output_file}")

if __name__ == "__main__":
    # Export is currently disabled; call main() here instead to regenerate
    # kozak_config_file.txt.
    print("skipping")

skipping


In [6]:
import pickle
import numpy as np
from typing import List, Tuple, Dict
import os
import warnings

def load_pickle_data(file_path: str) -> List[Dict]:
    """Read a pickle file and return the deserialized object.

    NOTE: unpickling can execute arbitrary code; only use this on trusted files.

    Args:
        file_path: path to the pickle file.

    Returns:
        The unpickled object (a list of prediction records in this notebook).
    """
    with open(file_path, 'rb') as stream:
        records = pickle.load(stream)
    return records

def find_first_tis_position(array: np.ndarray) -> int:
    """Return the start index of the first run of three consecutive 0 values.

    Args:
        array: prediction or true-label sequence.

    Returns:
        Index of the first element of the earliest 0,0,0 triple, or -1 if none.
    """
    # Track the length of the current run of zeros; the first time it reaches 3,
    # the triple started two positions earlier.
    run = 0
    for idx, value in enumerate(array):
        run = run + 1 if value == 0 else 0
        if run == 3:
            return idx - 2
    return -1

def find_first_tts_end_position(array: np.ndarray) -> int:
    """Return the index of the last element of the first run of three consecutive 1 values.

    Args:
        array: prediction or true-label sequence.

    Returns:
        Index of the third 1 in the earliest 1,1,1 triple, or -1 if none.
    """
    # The first time a run of ones reaches length 3, the current index is the
    # triple's last position.
    ones_in_a_row = 0
    for idx, value in enumerate(array):
        ones_in_a_row = ones_in_a_row + 1 if value == 1 else 0
        if ones_in_a_row == 3:
            return idx
    return -1

def collect_group_a_data(nm_data: List[Dict]) -> List[Tuple[str, Tuple[int, int]]]:
    """Collect group A (TIS, TTS) pairs from the ground-truth labels of NM transcripts.

    TIS is the start of the first run of three 0 labels. TTS is the last index of
    the first run of three 1 labels that begins after the TIS codon. Transcripts
    lacking either are skipped.

    Args:
        nm_data: NM prediction records, each with 'true_labels' and 'transcript_id'.

    Returns:
        List of (transcript_id, (tis_pos, tts_end)) with ``tis_pos < tts_end``.
    """
    group_a = []
    for item in nm_data:
        true_labels = np.array(item['true_labels'])
        transcript_id = item['transcript_id']

        tis_pos = find_first_tis_position(true_labels)
        if tis_pos == -1:
            continue

        # Look for the stop only downstream of the start codon so the pair always
        # describes a forward range (a stop triple upstream would invert it).
        codon_end = tis_pos + 3
        tts_offset = find_first_tts_end_position(true_labels[codon_end:])
        if tts_offset == -1:
            continue

        group_a.append((transcript_id, (tis_pos, codon_end + tts_offset)))

    return group_a

def collect_group_b_data(nr_data: List[Dict]) -> List[Tuple[str, Tuple[int, int]]]:
    """Collect group B (TIS, TTS) pairs from the model predictions of NR transcripts.

    TIS is the start of the first run of three 0 predictions. TTS is the last index
    of the first run of three 1 predictions that begins after the TIS codon.
    Transcripts lacking either are skipped.

    Args:
        nr_data: NR prediction records, each with 'predictions' and 'transcript_id'.

    Returns:
        List of (transcript_id, (tis_pos, tts_end)) with ``tis_pos < tts_end``.
    """
    group_b = []
    for item in nr_data:
        predictions = np.array(item['predictions'])
        transcript_id = item['transcript_id']

        tis_pos = find_first_tis_position(predictions)
        if tis_pos == -1:
            continue

        # Predictions can place a stop triple before the first start triple; search
        # only downstream of the start codon so the range is never inverted.
        codon_end = tis_pos + 3
        tts_offset = find_first_tts_end_position(predictions[codon_end:])
        if tts_offset == -1:
            continue

        group_b.append((transcript_id, (tis_pos, codon_end + tts_offset)))

    return group_b

def save_to_file(group_a: List[Tuple[str, Tuple[int, int]]], 
                 group_b: List[Tuple[str, Tuple[int, int]]], 
                 output_file: str):
    """Write both groups side by side as a ';'-separated two-column file.

    The first line is ``group_A;group_B``. Each non-empty cell is
    ``<transcript_id>,<tis_pos>-<tts_pos>``, and the shorter group is padded with
    empty cells.

    Args:
        group_a: (transcript_id, (tis_pos, tts_pos)) entries for column A.
        group_b: (transcript_id, (tis_pos, tts_pos)) entries for column B.
        output_file: destination path (overwritten).
    """
    def render(group, row):
        if row < len(group):
            transcript_id, (tis_pos, tts_pos) = group[row]
            return f"{transcript_id},{tis_pos}-{tts_pos}"
        return ""

    row_count = max(len(group_a), len(group_b))

    with open(output_file, 'w') as out:
        out.write('group_A;group_B\n')
        for row in range(row_count):
            left = render(group_a, row)
            right = render(group_b, row)
            out.write(f"{left};{right}\n")

def main(base_dir: str = "/home/jovyan/work/insilico_translation/embedding_type3_maxlen9995_ratio80_NM_NR"):
    """Export CDS ranges (TIS start - TTS end) of NM (group A) and NR (group B) transcripts.

    Group A uses the ground-truth labels of the NM pickle, and group B uses the
    predictions of the NR pickle. The result is written to
    ``<base_dir>/Statictis/CAI_CDS_config_file.txt``, the file read by the CAI/CDS
    analysis earlier in this notebook.

    Args:
        base_dir: model-output directory holding the NM/NR prediction pickles.
    """
    nm_file = os.path.join(base_dir, "TRANSAID_Embedding_batch4_NM_matching_predictions.pkl")
    nr_file = os.path.join(base_dir, "TRANSAID_Embedding_batch4_NR_non_matching_predictions.pkl")
    # 'Statictis' (sic) is the existing output folder name on disk; keep it unchanged.
    output_file = os.path.join(base_dir, "Statictis", "CAI_CDS_config_file.txt")

    print("Loading NM data...")
    nm_data = load_pickle_data(nm_file)
    print("Loading NR data...")
    nr_data = load_pickle_data(nr_file)

    print("Collecting group A data...")
    group_a = collect_group_a_data(nm_data)
    print(f"Found {len(group_a)} valid TIS-TTS pairs in group A")

    print("Collecting group B data...")
    group_b = collect_group_b_data(nr_data)
    print(f"Found {len(group_b)} valid TIS-TTS pairs in group B")

    print("Saving results...")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_to_file(group_a, group_b, output_file)
    print(f"Results saved to {output_file}")

if __name__ == "__main__":
    # Export is currently disabled; call main() here instead to regenerate
    # CAI_CDS_config_file.txt.
    print("skipping")

skipping
