In [1]:
%pip install numpy Bio


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
# save_as_fasta.py
from pathlib import Path

import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm

def onehot_to_seq(onehot, alphabet: str = "ATCG") -> str:
    """Convert a one-hot encoded sequence into a nucleotide string.

    Args:
        onehot: Either a flat array of length ``L * len(alphabet)`` (consecutive
            groups of ``len(alphabet)`` values form one position) or an array
            already shaped ``(L, len(alphabet))``. Lists are accepted too.
        alphabet: Single-character ASCII symbols, one per one-hot channel, in
            channel order. Defaults to ``"ATCG"``, the order this notebook has
            always used. NOTE(review): confirm it matches the encoding used
            to build the ``.npy`` files.

    Returns:
        The decoded sequence of length ``L`` ("" for empty input).

    Raises:
        ValueError: if the number of values is not a multiple of
            ``len(alphabet)`` (raised by ``reshape``).

    Each position decodes to the channel holding its largest value. Ties
    resolve to the first channel, so an all-zero row decodes to
    ``alphabet[0]``.
    """
    width = len(alphabet)
    # reshape(-1, width) accepts both the flat and the (L, width) layout.
    positions = np.asarray(onehot).reshape(-1, width)
    channel_idx = positions.argmax(axis=1)
    # Vectorised lookup: gather 1-byte symbols, then decode once.
    lookup = np.frombuffer(alphabet.encode("ascii"), dtype="S1")
    return lookup[channel_idx].tobytes().decode("ascii")

def _iter_fasta_records(X, indices, prefix, dataset_type, progress_bar):
    """Lazily decode the rows of ``X`` listed in ``indices`` into SeqRecords.

    Record ids are ``<prefix>_<dataset_type>_<row index>``. ``progress_bar`` is
    advanced by one per record, so a single bar can span several passes.
    """
    for i in indices:
        record = SeqRecord(
            seq=Seq(onehot_to_seq(X[i])),
            id=f"{prefix}_{dataset_type}_{i}",
            description=""
        )
        progress_bar.update(1)
        yield record


def convert_viRNAtrap_to_FASTA(
    x_path: str,
    y_path: str,
    output_dir: str,
    dataset_type: str = "train"
):
    """Split a one-hot encoded viRNAtrap dataset into virus/host FASTA files.

    Rows whose label is 1 go to
    ``<output_dir>/<dataset_type>/viruses/sequences.fasta``. Every other label
    goes to ``<output_dir>/<dataset_type>/host/sequences.fasta``. Record ids are
    ``virus_<dataset_type>_<i>`` / ``host_<dataset_type>_<i>``, where ``i`` is
    the row index in the input arrays. Within each file, records keep input
    order.

    Args:
        x_path: ``.npy`` file with one one-hot encoded sequence per row, decoded
            with ``onehot_to_seq``.
        y_path: ``.npy`` file with one label per row of ``x_path``; values are
            cast to ``int``.
        output_dir: Root output directory; sub-directories are created as needed.
        dataset_type: Split name used for the sub-directory and record ids.

    Raises:
        ValueError: if ``x_path`` and ``y_path`` contain different row counts.
    """
    # Memory-map X rather than loading it whole: a split can hold millions of
    # sequences, and each row is decoded exactly once below.
    X = np.load(x_path, mmap_mode="r")
    y = np.load(y_path).astype(int)
    if len(X) != len(y):
        # Without this check, mismatched inputs would silently lose rows.
        raise ValueError(
            f"x_path has {len(X)} rows but y_path has {len(y)} labels"
        )

    # Create the output directories.
    virus_dir = Path(output_dir) / dataset_type / "viruses"
    host_dir = Path(output_dir) / dataset_type / "host"
    virus_dir.mkdir(parents=True, exist_ok=True)
    host_dir.mkdir(parents=True, exist_ok=True)

    is_virus = y == 1
    with tqdm(
        total=len(X),
        desc=f"转换 {dataset_type} 数据集",
        unit="seq"
    ) as progress_bar:
        # Stream records to disk, one pass per output file, instead of keeping
        # every SeqRecord in memory until the end. Each row is decoded in
        # exactly one of the two passes.
        SeqIO.write(
            _iter_fasta_records(
                X, np.flatnonzero(is_virus), "virus", dataset_type, progress_bar
            ),
            virus_dir / "sequences.fasta",
            "fasta",
        )
        SeqIO.write(
            _iter_fasta_records(
                X, np.flatnonzero(~is_virus), "host", dataset_type, progress_bar
            ),
            host_dir / "sequences.fasta",
            "fasta",
        )
    print(f"\n✅ {dataset_type} 数据集转换完成！输出目录：{output_dir}/{dataset_type}")

if __name__ == "__main__":
    # tqdm is imported unconditionally at the top of this cell, so a missing
    # package already fails there. Install dependencies in the %pip cell
    # rather than attempting a runtime install here.

    # Cheap sanity check of the decoder before the multi-minute conversion:
    # with the A, T, C, G channel order, [1,0,0,0, 0,1,0,0] decodes to "AT".
    assert onehot_to_seq(np.array([1, 0, 0, 0, 0, 1, 0, 0])) == "AT"

    INPUT_DIR = "train_test_data"
    OUTPUT_DIR = "virhunter_data"

    # Convert each split. Inputs are <INPUT_DIR>/<split>_x.npy and
    # <INPUT_DIR>/<split>_y.npy.
    for split in ("train", "test"):
        convert_viRNAtrap_to_FASTA(
            x_path=f"{INPUT_DIR}/{split}_x.npy",
            y_path=f"{INPUT_DIR}/{split}_y.npy",
            output_dir=OUTPUT_DIR,
            dataset_type=split,
        )

转换 train 数据集: 100%|██████████| 8000000/8000000 [04:00<00:00, 33242.02seq/s]



✅ train 数据集转换完成！输出目录：virhunter_data/train


转换 test 数据集: 100%|██████████| 2558044/2558044 [01:05<00:00, 39216.16seq/s]



✅ test 数据集转换完成！输出目录：virhunter_data/test
