In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from ALLCools.motif.cistarget import cistarget_motif_enrichment
from scipy.sparse import load_npz


In [None]:
# Name of the group processed by this notebook; it is copied into `region`,
# which prefixes every per-run file name read or written further down.
group_name = "CTX"

In [3]:
# Alias of `group_name`; the remaining cells build their file names from it
# (e.g. f"{region}_motif.ds", f"{region}_AUC.hdf").
region = group_name

In [4]:
# Open this region's motif store and load its "rank" variable eagerly; the
# enrichment loop selects from it once per cluster.
motif_store = f"{region}_motif.ds"
motif_ds = xr.open_zarr(motif_store)
motif_ds["rank"].load()

# Index objects of the two dimensions that later cells refer to.
total_dmr, motif_id = (motif_ds.get_index(dim) for dim in ("dmr_id", "motif"))
print(motif_ds.dims)

In [5]:
# Lookup passed to Index.map in the enrichment loop to translate the ids stored
# in the index archive before they are matched against the motif dataset.
# NOTE(review): presumably a Series keyed by those ids -- confirm.
# NOTE(review): absolute, user-specific path; consider a configurable data dir.
dmr_map = pd.read_hdf(
    "/home/jzhou_salk_edu/sky_workdir/230209_dmrmotif/dmr_idmap.hdf", key="data"
)

# Sparse matrix that the enrichment loop reads one row at a time (row i belongs
# to the i-th cluster name).
dmr_status = load_npz(f"{region}_hypo.npz")

# Open the index archive once and close it deterministically: np.load returns a
# lazy NpzFile that keeps the underlying file open until it is closed, so the
# previous two separate np.load calls each leaked a handle.
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load archives from a trusted source.
with np.load(f"{region}_index.npz", allow_pickle=True) as index_npz:
    dmr_id = index_npz["dmr"]
    cluster_names = index_npz["cluster"]

# Series mapping cluster name -> position in `cluster_names` (0..n-1); the
# position is carried by the "index" column that reset_index() creates.
cluster_id = (
    pd.Series(cluster_names, name="cluster_id")
    .reset_index()
    .set_index("cluster_id")["index"]
)
print(dmr_status.shape)


In [6]:
# Settings forwarded unchanged to every cistarget_motif_enrichment call.
enrichment_kwargs = dict(
    auc_threshold=0.005,
    nes_threshold=3,
    rank_threshold=0.05,
    full_motif_stats=True,
)

# One entry per cluster, appended in the order of cluster_id.index.
auc_df = []
nes_df = []
for i, cluster in enumerate(cluster_id.index):
    # Positions stored in row i of the sparse status matrix pick this
    # cluster's entries out of dmr_id.
    # NOTE(review): relies on row i matching the i-th cluster name -- confirm.
    status_columns = dmr_status[i].indices
    seldmr = pd.Index(dmr_id[status_columns]).map(dmr_map)
    # Keep only the translated ids that are present on the motif dataset.
    seldmr = total_dmr.intersection(seldmr)
    motif_df = motif_ds.sel(dmr_id=seldmr)["rank"].to_pandas()

    motif_enrichment, motif_hit, full_stats = cistarget_motif_enrichment(
        rank_df=motif_df,
        total_regions=total_dmr.size,
        **enrichment_kwargs,
    )

    # Write this cluster's hits to its own zarr store, replacing any old one.
    motif_result = xr.Dataset({"hits": motif_hit})
    motif_result["hits"].encoding = {"chunks": (50, 1000000)}
    motif_result.to_zarr(f"{cluster}.enriched_motif_dmr.zarr", mode="w")

    # Collect the per-cluster statistics for the summary tables.
    auc_df.append(full_stats["AUC"])
    nes_df.append(full_stats["NES"])
    print(cluster)
    

CTX-0
CTX-1
CTX-10
CTX-11
CTX-12
CTX-13
CTX-14
CTX-15
CTX-16
CTX-17
CTX-18
CTX-19
CTX-2
CTX-20
CTX-21
CTX-22
CTX-23
CTX-24
CTX-25
CTX-26
CTX-27
CTX-28
CTX-29
CTX-3
CTX-30
CTX-31
CTX-32
CTX-33
CTX-34
CTX-35
CTX-36
CTX-37
CTX-38
CTX-39
CTX-4
CTX-40
CTX-41
CTX-42
CTX-43
CTX-44
CTX-45
CTX-46
CTX-47
CTX-48
CTX-49
CTX-5
CTX-50
CTX-51
CTX-52
CTX-53
CTX-54
CTX-55
CTX-6
CTX-7
CTX-8
CTX-9


In [10]:
# Turn the per-cluster lists into tables with one column per cluster.
# NOTE: auc_df / nes_df are rebound from lists to DataFrames here, so this cell
# can only be re-run after re-running the enrichment loop.
cluster_labels = cluster_id.index
auc_df = pd.concat(auc_df, axis=1)
nes_df = pd.concat(nes_df, axis=1)
for stats_table in (auc_df, nes_df):
    # Columns come out in loop order; relabel them with the cluster names.
    stats_table.columns = cluster_labels

# Persist both tables under the same HDF key.
auc_df.to_hdf(f"{region}_AUC.hdf", key="data")
nes_df.to_hdf(f"{region}_NES.hdf", key="data")
