In [3]:
# -*- coding: utf-8 -*-
"""
ADNI preprocessing (one-pass, simplified & robust):
- Build Y (baseline CN/AD), PTID list
- Read multiple PLINK sets, take common SNPs (keeping order of the first set)
- Filter invalid loci, assemble genotype matrix (samples x SNPs)
- Overlap with Gencode v19 (hg19) to label exonic vs non-exonic
- Split to X (exonic), C (non-exonic)
- Deduplicate & align to baseline Y by PTID (keep first occurrence)
- Build covariates [sex_male, gender_male] (robust to duplicate column names)
- Save aligned X, C, Y, covariates, ids to Prepared_data2

Requirements:
  pip install pandas numpy pandas-plink pyranges scikit-learn
"""

from pathlib import Path
import gzip, shutil, re, sys
import numpy as np
import pandas as pd
from pandas_plink import read_plink
import pyranges as pr
from sklearn.preprocessing import normalize

# ---------- Paths ----------
ROOT = Path("/Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data")
OUT  = ROOT / "Prepared_data4"
OUT.mkdir(parents=True, exist_ok=True)

DIAG_CSV = ROOT / "ADNI_Diagnosis.csv"
DEMO_CSV = ROOT / "ADNI_Demographics.csv"

# Gencode v19 (hg19/GRCh37)
GTF_GZ    = ROOT / "gencode.v19.annotation.gtf.gz"
GTF_PLAIN = ROOT / "gencode.v19.annotation.gtf"
GTF = GTF_PLAIN if GTF_PLAIN.exists() else GTF_GZ

PLINK_DIRS = [
    ROOT / "ADNI1_GWAS_PLINK",
    ROOT / "ADNI_GO2_GWAS_PLINK",
    ROOT / "ADNI_GO2_GWAS_PLINK2",
    ROOT / "ADNI3_GWAS_PLINK1",
    ROOT / "ADNI3_GWAS_PLINK2",
]

# ---------- Helpers ----------
_ptid_pat = re.compile(r"(\d{3}_S_\d{4})")

def norm_ptid(s: str) -> str | None:
    s = str(s).strip()
    m = _ptid_pat.search(s)
    return m.group(1) if m else None

def detect_bfile(folder: Path) -> Path:
    cands = []
    for bed in folder.glob("*.bed"):
        stem = bed.with_suffix("")
        if stem.with_suffix(".bim").exists() and stem.with_suffix(".fam").exists():
            cands.append(stem)
    if len(cands) != 1:
        raise RuntimeError(f"[{folder.name}] expected 1 PLINK trio, got {len(cands)}: {cands}")
    return cands[0]

def split_plink(a, b, c):
    """Identify (G, bim, fam) robustly."""
    objs = [a, b, c]
    dfs  = [o for o in objs if hasattr(o, "columns")]
    arrs = [o for o in objs if not hasattr(o, "columns")]
    if not (len(dfs) == 2 and len(arrs) == 1):
        raise TypeError(f"Unexpected read_plink returns: {[type(o) for o in objs]}")
    def is_bim(df): return {'chrom','pos','snp'}.issubset(set(df.columns))
    def is_fam(df): return 'iid' in df.columns
    if is_bim(dfs[0]) and is_fam(dfs[1]): return arrs[0], dfs[0], dfs[1]
    if is_bim(dfs[1]) and is_fam(dfs[0]): return arrs[0], dfs[1], dfs[0]
    raise ValueError(f"Cannot detect BIM/FAM. DF1 cols={list(dfs[0].columns)}, DF2 cols={list(dfs[1].columns)}")

def to_chr_label(val) -> str:
    s = str(val)
    if s in ['23','X','x']: return 'chrX'
    if s in ['24','Y','y']: return 'chrY'
    if s in ['25','MT','Mt','mt','M','m']: return 'chrM'
    if s.lower().startswith('chr'):
        return s if s.startswith('chr') else 'chr' + s.split('chr',1)[-1]
    return 'chr' + s

def ensure_gtf_plain(gtf_path: Path) -> Path:
    if gtf_path.suffix == ".gz":
        if not GTF_PLAIN.exists():
            print("[INFO] Decompressing GTF…")
            with gzip.open(gtf_path, "rb") as f_in, open(GTF_PLAIN, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        return GTF_PLAIN
    return gtf_path

def get_first_series(df: pd.DataFrame, col: str) -> pd.Series:
    """Return the first column named `col` as a Series, even if duplicates exist."""
    obj = df[col]
    if isinstance(obj, pd.DataFrame):
        return obj.iloc[:, 0]
    return obj

def make_male_from_column(df: pd.DataFrame, col: str) -> pd.Series:
    s = get_first_series(df, col).astype(str).str.upper().str.strip()
    male = pd.Series(np.nan, index=df.index, dtype=float)
    male[s.isin(["M", "MALE", "1"])] = 1.0
    male[s.isin(["F", "FEMALE", "0", "2"])] = 0.0
    # fuzzy fallbacks
    male[s.str.contains("MALE", na=False)] = 1.0
    male[s.str.contains("FEM",  na=False)] = 0.0
    return male

def drop_nan_inf_cols(A: np.ndarray, name: str):
    A = np.asarray(A, dtype=np.float64)
    bad_nan = np.isnan(A).any(axis=0)
    bad_inf = ~np.isfinite(A).all(axis=0)
    keep = ~(bad_nan | bad_inf)
    removed = int((~keep).sum())
    if removed > 0:
        print(f"[CLEAN] {name}: removed {removed} columns with NaN/Inf; kept {keep.sum()}.")
    else:
        print(f"[CLEAN] {name}: no NaN/Inf columns; kept {A.shape[1]}.")
    return A[:, keep], keep

# ---------- 1) Build baseline Y (CN=0, AD=1) ----------
diag = pd.read_csv(DIAG_CSV)
diag["VISCODE"] = diag["VISCODE"].astype(str).str.lower()
bl = diag[(diag["VISCODE"]=="bl") & (diag["DIAGNOSIS"].isin([1,3]))].copy()
bl["Y"] = bl["DIAGNOSIS"].map({1:0, 3:1}).astype(int)
bl["PTID"] = bl["PTID"].astype(str).apply(norm_ptid)
y_tbl = bl.dropna(subset=["PTID"])[["PTID","Y"]].drop_duplicates(subset=["PTID"]).reset_index(drop=True)
print(f"[Y] n={len(y_tbl)} | counts -> 0={int((y_tbl['Y']==0).sum())}  1={int((y_tbl['Y']==1).sum())}")

# ---------- 2) Read PLINK batches & build common SNP matrix ----------
bfiles = [detect_bfile(d) for d in PLINK_DIRS]
print("[INFO] Detected PLINK basenames:")
for s in bfiles: print("  -", s)

common_snps = None
bims = []
blocks = []
all_ids = []

for i, stem in enumerate(bfiles):
    a, b, c = read_plink(str(stem))
    G, bim, fam = split_plink(a, b, c)  # G: (variants x samples) dask/xarray-like

    # Filter invalid
    bim = bim.copy()
    bim['chrom'] = bim['chrom'].astype(str)
    bim = bim[(bim['chrom']!='0') & (bim['pos'].astype(int) > 0)]
    snps = bim['snp'].astype(str).tolist()

    if common_snps is None:
        common_snps = snps
    else:
        s = set(common_snps).intersection(snps)
        common_snps = [x for x in common_snps if x in s]  # keep order of the first batch

    bims.append(bim)

if not common_snps:
    raise ValueError("No common SNPs across batches. Check build/coordinates.")

pd.Series(common_snps).to_csv(OUT/"markers.txt", index=False, header=False)
print(f"[INFO] markers.txt written: {len(common_snps)} SNPs")

# Assemble matrix per batch (samples x common_snps), stack vertically
for bim, stem in zip(bims, bfiles):
    snp2idx = {s:i for i,s in enumerate(bim['snp'].astype(str).tolist())}
    take = [snp2idx[s] for s in common_snps if s in snp2idx]
    if len(take) != len(common_snps):
        raise RuntimeError(f"[{stem.name}] missing some common SNPs unexpectedly.")

    a, b, c = read_plink(str(stem))
    G, bim2, fam = split_plink(a, b, c)

    sub = G[take, :].compute().astype(np.float32).T  # (samples x variants)

    # mean-impute NaNs per column
    if np.isnan(sub).any():
        col_means = np.nanmean(sub, axis=0)
        inds = np.where(np.isnan(sub))
        sub[inds] = np.take(col_means, inds[1])

    ids = fam['iid'].astype(str).apply(norm_ptid).tolist()
    all_ids.extend(ids)
    blocks.append(sub)

G_all = np.vstack(blocks)  # (n_samples_all x n_common_snps)
pd.Series(all_ids).to_csv(OUT/"samples_raw.txt", index=False, header=False)
print(f"[INFO] G_all shape: {G_all.shape} | samples_raw.txt n={len(all_ids)}")

# ---------- 3) Exon overlap via Gencode v19 ----------
gtf_path = ensure_gtf_plain(GTF)
gtf = pr.read_gtf(str(gtf_path))
feat_col = "feature" if "feature" in gtf.df.columns else ("Feature" if "Feature" in gtf.df.columns else None)
if feat_col is None:
    raise KeyError(f"GTF lacks 'feature' column. Got {list(gtf.df.columns)}")
exon = gtf[gtf.df[feat_col] == "exon"]

# Coordinates from FIRST batch (order = common_snps)
bim0 = bims[0].set_index("snp").loc[common_snps].reset_index()
df_snps = pd.DataFrame({
    "Chromosome": bim0["chrom"].apply(to_chr_label),
    "Start": bim0["pos"].astype(int),
    "End":   bim0["pos"].astype(int) + 1,
    "Name":  bim0["snp"].astype(str),
})
gr_snps = pr.PyRanges(df_snps)

ovl_df = gr_snps.join(exon).df
exon_names = set(ovl_df["Name"].astype(str))
is_exonic = df_snps["Name"].astype(str).isin(exon_names).to_numpy()
n_all = len(common_snps); n_ex = int(is_exonic.sum()); n_non = n_all - n_ex
print(f"[INFO] Exon tagging: total={n_all} | exonic={n_ex} ({n_ex/n_all:.2%}) | non-exonic={n_non} ({n_non/n_all:.2%})")

# Save marker aides
np.save(OUT/"is_exonic.npy", is_exonic.astype(bool))
df_snps.assign(is_exonic=is_exonic).to_csv(OUT/"markers.tsv", sep="\t", index=False)
with open(OUT/"exonMarkers.txt","w") as f:
    for rs in df_snps.loc[is_exonic, "Name"]: f.write(str(rs) + "\n")
with open(OUT/"nonExonMarkers.txt","w") as f:
    for rs in df_snps.loc[~is_exonic, "Name"]: f.write(str(rs) + "\n")

# ---------- 4) Split to X/C, clean NaN/Inf columns ----------
X_full = G_all[:, is_exonic]
C_full = G_all[:, ~is_exonic]
X_full, maskX = drop_nan_inf_cols(X_full, "X")
C_full, maskC = drop_nan_inf_cols(C_full, "C")

# Sync marker lists after cleaning
exon_markers_after = df_snps.loc[is_exonic, "Name"].astype(str).to_numpy()[maskX]
nonexon_markers_after = df_snps.loc[~is_exonic, "Name"].astype(str).to_numpy()[maskC]
pd.Series(exon_markers_after).to_csv(OUT/"exonMarkers_after_clean.txt", index=False, header=False)
pd.Series(nonexon_markers_after).to_csv(OUT/"nonExonMarkers_after_clean.txt", index=False, header=False)

# ---------- 5) Deduplicate PTIDs, align to baseline Y ----------
ids_raw = pd.Series(all_ids, name="PTID_raw").to_frame()
ids_raw["PTID"] = ids_raw["PTID_raw"].astype(str).apply(norm_ptid)
ids_raw = ids_raw.dropna(subset=["PTID"]).reset_index(drop=True)

dup_mask = ids_raw.duplicated(subset=["PTID"], keep="first")
keep_rows = (~dup_mask).to_numpy()
ids_dedup = ids_raw.loc[~dup_mask, ["PTID"]].reset_index(drop=True)

X_dedup = X_full[keep_rows, :]
C_dedup = C_full[keep_rows, :]

# left-join on PTID to get Y; keep only rows with Y
merged = ids_dedup.merge(y_tbl, on="PTID", how="left")
mask_have_y = merged["Y"].notna().to_numpy()

X = X_dedup[mask_have_y, :]
C = C_dedup[mask_have_y, :]
Y = merged.loc[mask_have_y, "Y"].to_numpy(dtype=int)
IDs = merged.loc[mask_have_y, "PTID"].to_numpy()

print(f"[ALIGN] X={X.shape}, C={C.shape}, Y={Y.shape}, IDs={IDs.shape}")

# ---------- 6) Build covariates [sex_male, gender_male] ----------
demo = pd.read_csv(DEMO_CSV)

# show duplicate columns info (visibility)
dup_info = demo.columns.to_series().groupby(demo.columns).size()
dups = dup_info[dup_info > 1]
if not dups.empty:
    print("[WARN] Duplicate column names detected in demographics:\n", dups)

# robustly pick PTID column
ptid_col = None
for c in ["PTID","ptid","Subject","subject","PTIDNUM"]:
    if c in demo.columns:
        ptid_col = c; break
if ptid_col is None:
    raise KeyError("No PTID-like column in ADNI_Demographics.csv")

# choose two sources (allow duplicates)
cands = ["SEX","Sex","PTGENDER","GENDER"]
sex_src = None; gender_src = None
for c in cands:
    if c in demo.columns:
        if sex_src is None: sex_src = c
        elif gender_src is None and c != sex_src: gender_src = c
if sex_src is None and gender_src is None:
    raise KeyError("No sex/gender columns found in demographics.")

if sex_src is None:    sex_src = gender_src
if gender_src is None: gender_src = sex_src

print(f"[INFO] Using PTID col: {ptid_col} | sex src: {sex_src} | gender src: {gender_src}")

tmp = demo[[ptid_col, sex_src, gender_src]].copy()
tmp[ptid_col] = tmp[ptid_col].astype(str).apply(norm_ptid)
tmp = tmp.dropna(subset=[ptid_col]).drop_duplicates(subset=[ptid_col]).reset_index(drop=True)

tmp["sex_male"]    = make_male_from_column(tmp, sex_src)
tmp["gender_male"] = make_male_from_column(tmp, gender_src)

cov = pd.DataFrame({"PTID": IDs}).merge(
    tmp[[ptid_col, "sex_male", "gender_male"]].rename(columns={ptid_col:"PTID"}),
    on="PTID", how="left"
)

# fill missing by mode (default 0)
for col in ["sex_male","gender_male"]:
    mode = cov[col].dropna().mode()
    fill = float(mode.iloc[0]) if not mode.empty else 0.0
    cov[col] = cov[col].astype(float).fillna(fill)

covariates = cov[["sex_male","gender_male"]].to_numpy(dtype=float)
# match old pipeline: L2 normalize columns
covariates = normalize(covariates, axis=0)

print(f"[COV] covariates shape: {covariates.shape}")

# ---------- 7) Save aligned outputs to Prepared_data2 ----------
np.save(OUT/"X.npy", X)
np.save(OUT/"C.npy", C)
np.save(OUT/"Y.npy", Y)
np.save(OUT/"covariates.npy", covariates)
pd.Series(IDs).to_csv(OUT/"ids.txt", index=False, header=False)

# helpful metadata
with open(OUT/"REPORT.txt","w") as f:
    f.write(f"X shape: {X.shape}\nC shape: {C.shape}\nY shape: {Y.shape}\n")
    f.write(f"covariates shape: {covariates.shape}\n")
    f.write(f"IDs n: {len(IDs)}\n")
    f.write(f"exonic kept: {X.shape[1]} | non-exonic kept: {C.shape[1]}\n")

print("\n[DONE] Saved aligned datasets to:", OUT)
print("       - X.npy, C.npy, Y.npy, covariates.npy, ids.txt")
print("       - markers.tsv, exonMarkers.txt, nonExonMarkers.txt, is_exonic.npy")
print("       - REPORT.txt")


[Y] n=874 | counts -> 0=608  1=266
[INFO] Detected PLINK basenames:
  - /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/ADNI1_GWAS_PLINK/ADNI_cluster_01_forward_757LONI
  - /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/ADNI_GO2_GWAS_PLINK/ADNI_GO_2_Forward_Bin
  - /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/ADNI_GO2_GWAS_PLINK2/ADNI_GO2_GWAS_2nd_orig_BIN
  - /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/ADNI3_GWAS_PLINK1/ADNI3_PLINK_Final
  - /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/ADNI3_GWAS_PLINK2/ADNI3_PLINK_FINAL_2nd


Mapping files: 100%|██████████| 3/3 [00:01<00:00,  2.55it/s]
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  4.30it/s]
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  5.98it/s]
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  6.09it/s]
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  7.77it/s]


[INFO] markers.txt written: 74886 SNPs


Mapping files: 100%|██████████| 3/3 [00:00<00:00,  9.27it/s]
  col_means = np.nanmean(sub, axis=0)
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  7.57it/s]
  col_means = np.nanmean(sub, axis=0)
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  7.34it/s]
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  5.22it/s]
  col_means = np.nanmean(sub, axis=0)
Mapping files: 100%|██████████| 3/3 [00:00<00:00,  8.72it/s]
  col_means = np.nanmean(sub, axis=0)


[INFO] G_all shape: (2205, 74886) | samples_raw.txt n=2205


join: Strand data from other will be added as strand data to self.
If this is undesired use the flag apply_strand_suffix=False.


[INFO] Exon tagging: total=74886 | exonic=3262 (4.36%) | non-exonic=71624 (95.64%)
[CLEAN] X: removed 14 columns with NaN/Inf; kept 3248.
[CLEAN] C: removed 346 columns with NaN/Inf; kept 71278.
[ALIGN] X=(815, 3248), C=(815, 71278), Y=(815,), IDs=(815,)
[INFO] Using PTID col: PTID | sex src: PTGENDER | gender src: PTGENDER
[COV] covariates shape: (815, 2)

[DONE] Saved aligned datasets to: /Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/Prepared_data4
       - X.npy, C.npy, Y.npy, covariates.npy, ids.txt
       - markers.tsv, exonMarkers.txt, nonExonMarkers.txt, is_exonic.npy
       - REPORT.txt


In [4]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
import pandas as pd

PREP = Path("/Users/zhangjiahui/Desktop/Haohan Research/Alz GWAS data/Prepared_data4")
C = np.load(PREP / "C.npy")

# 标准化 & 近零方差过滤
Z = StandardScaler(with_mean=True, with_std=True).fit_transform(C)
std = Z.std(axis=0)
keep = std > 1e-8
Z = Z[:, keep]
orig_cols = np.flatnonzero(keep)
print(f"[INFO] usable genes: {Z.shape[1]} / {C.shape[1]}")

# k=2 的 silhouette；最小簇≥10%
def sil_one(x):
    x = x.reshape(-1,1)
    lab = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(x)
    if min(np.bincount(lab)) / len(lab) < 0.10:
        return np.nan
    return silhouette_score(x, lab, metric="euclidean")

scores = np.array([sil_one(Z[:, j]) for j in range(Z.shape[1])])

# 分布图
plt.figure(figsize=(6,4))
plt.hist(scores[np.isfinite(scores)], bins=60, edgecolor="white")
plt.xlabel("Silhouette score (k=2)"); plt.ylabel("Count")
plt.title("Distribution of per-gene Silhouette scores")
plt.tight_layout()
plt.savefig(PREP / "C_silhouette_distribution_k2.png", dpi=200)
plt.close()

# Top-3000
valid = np.where(np.isfinite(scores))[0]
topk = min(3000, valid.size)
top_idx_local = valid[np.argsort(scores[valid])[::-1][:topk]]
top_idx_orig = orig_cols[top_idx_local]

C_top = C[:, top_idx_orig]
np.save(PREP / "C_silhouette_top3000.npy", C_top)
np.save(PREP / "C_silhouette_top3000_colidx.npy", top_idx_orig)
print(f"[SAVE] C_silhouette_top3000.npy shape={C_top.shape}")

# 可选：输出对应的标记名
markers_path = PREP / "nonExonMarkers_after_clean.txt"
if markers_path.exists():
    mk = pd.read_csv(markers_path, header=None)[0].astype(str)
    if len(mk) == C.shape[1]:
        mk_sel = mk.iloc[top_idx_orig].reset_index(drop=True)
        mk_sel.to_csv(PREP / "C_silhouette_top3000_markers.txt", index=False, header=False)
        print("[SAVE] C_silhouette_top3000_markers.txt")


[INFO] usable genes: 71272 / 71278
[SAVE] C_silhouette_top3000.npy shape=(815, 3000)
[SAVE] C_silhouette_top3000_markers.txt
