In [1]:
import os
import re
import glob
import shutil
from pathlib import Path

import pandas as pd

In [2]:
def normalize_strain_name(name: str) -> str:
    """
    将菌株名称统一标准化，适用于 Subject 字段和文件名 stem。

    规则：
    - 下划线 -> 空格
    - ATCC_10987 / ATCC10987 -> ATCC 10987
    - 同理处理 DSM / JCM / NBRC 等编号
    - sp, sp., spp, spp. 统一成 'sp.'
    - 多个空格压缩成一个
    """
    # 1) 下划线全部变空格
    name = name.replace("_", " ")

    # 2) 统一 ATCC / DSM / JCM / NBRC 等收藏号写法
    name = re.sub(r"(ATCC)\s*[_]?\s*(\d+)", r"\1 \2", name, flags=re.IGNORECASE)
    name = re.sub(r"(DSM)\s*[_]?\s*(\d+)",  r"\1 \2", name, flags=re.IGNORECASE)
    name = re.sub(r"(JCM)\s*[_]?\s*(\d+)",  r"\1 \2", name, flags=re.IGNORECASE)
    name = re.sub(r"(NBRC)\s*[_]?\s*(\d+)", r"\1 \2", name, flags=re.IGNORECASE)

    # 3) 统一 sp/sp./spp/spp. 为 'sp.'
    name = re.sub(r"\bsp[p]?\.?\b", "sp.", name, flags=re.IGNORECASE)

    # 4) 多个空格压缩成一个
    name = re.sub(r"\s+", " ", name).strip()

    return name


def build_strain_file_map(genelist_dir: Path) -> dict:
    """
    遍历 GENELIST_DIR 中的 *.genes.tsv，构建
    normalized_strain_name -> 文件路径 的映射。
    """
    strain2file = {}
    duplicated_keys = []

    pattern = str(genelist_dir / "*.genes.tsv")
    for path in glob.glob(pattern):
        path = Path(path)
        fname = path.name                       # Bacillus_cereus_ATCC_10987.genes.tsv
        stem = fname[:-len(".genes.tsv")]       # Bacillus_cereus_ATCC_10987
        strain_name = normalize_strain_name(stem)

        if strain_name in strain2file:
            duplicated_keys.append(strain_name)
        strain2file[strain_name] = str(path)

    print(f"[INFO] 共索引到 {len(strain2file)} 个基因列表文件")
    if duplicated_keys:
        print(f"[WARN] 有 {len(duplicated_keys)} 个标准化后名称重复，将使用最后一个文件：")
        for k in duplicated_keys[:10]:
            print("       -", k)
        if len(duplicated_keys) > 10:
            print("       ...")

    return strain2file


def extract_gene_triples(kg_path: Path, out_path: Path) -> pd.DataFrame:
    """
    从 AResKG 全量文件中筛选和 gene / mutation / 抗性相关的谓词，
    保存到中间文件并返回 DataFrame。
    """
    print(f"[INFO] 读取 KG：{kg_path}")
    kg = pd.read_csv(kg_path, sep="\t")

    predicates = ["has gene", "has mutation", "resistant to", "sensitive to"]
    df_gene = kg[kg["Predicate"].isin(predicates)].copy()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    df_gene.to_csv(out_path, sep="\t", index=False)
    print(f"[INFO] 已将筛选后的三元组写入：{out_path}，共 {len(df_gene)} 行")
    return df_gene


def mark_subjects_with_genelist(df_gene: pd.DataFrame, strain2file: dict):
    """
    给 df_gene 新增：
      - NormalizedSubject
      - has_genelist (是否能在 strain2file 中找到)
    并返回 (subjects_with_genelist, subjects_without_genelist)
    """
    df_gene = df_gene.copy()
    df_gene["NormalizedSubject"] = df_gene["Subject"].astype(str).apply(normalize_strain_name)

    df_gene["has_genelist"] = df_gene["NormalizedSubject"].isin(strain2file.keys())

    subjects_with = sorted(df_gene.loc[df_gene["has_genelist"], "NormalizedSubject"].unique())
    subjects_without = sorted(df_gene.loc[~df_gene["has_genelist"], "NormalizedSubject"].unique())

    print(f"[INFO] 匹配成功菌株数：{len(subjects_with)}")
    print(f"[INFO] 匹配失败菌株数：{len(subjects_without)}")

    return df_gene, subjects_with, subjects_without


def save_subjects_without_genelist(subjects_without: list, processed_dir: Path):
    processed_dir.mkdir(parents=True, exist_ok=True)
    out_path = processed_dir / "subjects_without_genelist.tsv"
    pd.Series(subjects_without).to_csv(out_path, sep="\t", index=False, header=False)
    print(f"[INFO] 未匹配到 genelist 的菌株列表已保存至：{out_path}")


def copy_matched_genelist_files(subjects_with: list, strain2file: dict, output_dir: Path):
    output_dir.mkdir(parents=True, exist_ok=True)
    copied = 0

    for norm_name in subjects_with:
        src = strain2file.get(norm_name)
        if src is None:
            # 理论上不会出现，但防御性处理一下
            print(f"[WARN] {norm_name} 在 strain2file 中未找到对应文件，跳过。")
            continue

        src_path = Path(src)
        dst_path = output_dir / src_path.name

        shutil.copy2(src_path, dst_path)
        copied += 1

    print(f"[DONE] 已复制 {copied} 个基因文件到：{output_dir}")


def save_df_gene_has_genelist(df_gene: pd.DataFrame, out_dir: Path, n_strains: int):
    out_dir.mkdir(parents=True, exist_ok=True)
    df_has = df_gene.loc[df_gene["has_genelist"]].reset_index(drop=True)
    out_path = out_dir / f"strain_species_{n_strains}_norm.tsv"
    df_has.to_csv(out_path, sep="\t", index=False)
    print(f"[INFO] 已保存含 genelist 的三元组子集：{out_path}，行数={len(df_has)}")

In [3]:
current_path = Path.cwd()
home_path = current_path.parent

# ---- 路径配置 ----
kg_file = home_path / "data/0-raw_data/AResKG_1117.txt"
processed_dir = home_path / "data/1-processed_data"
genelist_dir = Path("/apdcephfs_qy3/share_2932069/kangcz/StrainNetwork/Strain")
output_genelist_dir = home_path / "data/2-strains_gene_list"

In [4]:
# 1) 从 KG 中抽取与 gene/抗性相关的三元组
strain_species_kg_path = processed_dir / "strain_species_kg.tsv"
if strain_species_kg_path.exists():
    print(f"[INFO] 发现已有中间文件，直接读取：{strain_species_kg_path}")
    df_gene = pd.read_csv(strain_species_kg_path, sep="\t")
else:
    df_gene = extract_gene_triples(kg_file, strain_species_kg_path)
df_gene.head()

[INFO] 发现已有中间文件，直接读取：/opt/ai4g_chriszyyang/buddy1/2_project_ongoing/4-antibio_resistance/PANACEA/data/1-processed_data/strain_species_kg.tsv


Unnamed: 0,Subject,Predicate,Object,Species
0,Clostridioides difficile NAPCR1,has gene,vanG,Clostridioides difficile
1,Clostridioides difficile NAPCR1,resistant to,Vancomycin,Clostridioides difficile
2,Clostridioides difficile R20291,has gene,vanG,Clostridioides difficile
3,Clostridioides difficile R20291,has mutation,R314L,Clostridioides difficile
4,Clostridioides difficile R20291,resistant to,Vancomycin,Clostridioides difficile


In [5]:
# 2) 建立 normalized strain name -> genes.tsv 映射
strain2file = build_strain_file_map(genelist_dir)

[INFO] 共索引到 1747 个基因列表文件
[WARN] 有 12 个标准化后名称重复，将使用最后一个文件：
       - Pseudomonas aeruginosa ATCC 27853
       - Escherichia coli ATCC 8739
       - Acinetobacter baumannii ATCC 17978
       - Staphylococcus epidermidis ATCC 12228
       - Helicobacter pylori ATCC 43504
       - Streptococcus pneumoniae ATCC 49619
       - Bacillus subtilis ATCC 6633
       - Mycobacterium tuberculosis H37Rv ATCC 27294
       - Helicobacter pylori ATCC 700392
       - Acinetobacter baumannii ATCC 19606
       ...


In [6]:
# 3) 标记 df_gene 中哪些主语有 genelist
df_gene_marked, subjects_with, subjects_without = mark_subjects_with_genelist(
    df_gene, strain2file
)

[INFO] 匹配成功菌株数：866
[INFO] 匹配失败菌株数：819


In [7]:
# 4) 保存未匹配菌株列表
save_subjects_without_genelist(subjects_without, processed_dir)

[INFO] 未匹配到 genelist 的菌株列表已保存至：/opt/ai4g_chriszyyang/buddy1/2_project_ongoing/4-antibio_resistance/PANACEA/data/1-processed_data/subjects_without_genelist.tsv


In [8]:
# 5) 复制匹配成功的基因列表文件
copy_matched_genelist_files(subjects_with, strain2file, output_genelist_dir)

[DONE] 已复制 866 个基因文件到：/opt/ai4g_chriszyyang/buddy1/2_project_ongoing/4-antibio_resistance/PANACEA/data/2-strains_gene_list


In [9]:
# 6) 保存只保留 has_genelist 的 df_gene 子集
save_df_gene_has_genelist(df_gene_marked, processed_dir, len(subjects_with))

[INFO] 已保存含 genelist 的三元组子集：/opt/ai4g_chriszyyang/buddy1/2_project_ongoing/4-antibio_resistance/PANACEA/data/1-processed_data/strain_species_866_norm.tsv，行数=8824
