In [1]:
""" parameters """

# --- input / output folders -------------------------------------------------
# NOTE(review): machine-specific absolute path; consider reading it from an
# environment variable so the notebook can run elsewhere.
data_fold_cpu = "/mnt/efs_v2/tag_onc/users/tianrui.qi/TCGA-Onc/data/"  # holds profile.txt and csv/
data_fold_gpu = "data/"  # UMAP projections and the hexbin figure are written below this folder

# --- analysis parameters ----------------------------------------------------
# NOTE: this name shadows the built-in chr() for the rest of the notebook.
chr = "6"           # chromosome; selects csv/{run}/{chr}.csv
pval_thresh = 1e-5  # p-value threshold; str(pval_thresh) ("1e-05") names the csv column used as a read filter

# --- UMAP checkpoint --------------------------------------------------------
umap_fold = "ckpt/umap/"
umap_ckpt = "01"    # fitted reducer lives at {umap_fold}/{umap_ckpt}.sav

In [1]:
""" import """

import torch
import transformers

import cuml
import numpy as np
import pandas as pd

import os
import joblib
import warnings
import seaborn as sns
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import matplotlib.patches

# Ignore the "Unable to import Triton..." warning. `message` is a regular
# expression matched against the start of the warning text.
warnings.filterwarnings(action="ignore", message="Unable to import Triton*")
# Only let transformers emit log records at ERROR level or above.
transformers.logging.set_verbosity_error()

In [None]:
""" profile """

# Sample sheet with one row per sequencing run.
profile = pd.read_csv(os.path.join(data_fold_cpu, "profile.txt"))

# Isolate ids look like "su001": strip the two-character prefix, keep the number.
profile["Isolate"] = profile["Isolate"].str[2:].astype("int64")

# Treatment encoded as 0 when the label contains "pre", otherwise 1.
mentions_pre = profile["Treatment"].str.contains("pre", regex=False)
profile["Treatment"] = (~mentions_pre).astype("int64")

# Order rows: Isolate ascending, pre before post, Tissue in descending string
# order (normal before BCC), then renumber the index 0..n-1.
sort_keys = ["Isolate", "Treatment", "Tissue"]
profile = profile.sort_values(by=sort_keys, ascending=[True, True, False])
profile = profile.reset_index(drop=True)

In [None]:
""" DNABERT2 Embedding """

# Default DNABERT-2 checkpoint (Hugging Face hub id).
_DNABERT2_DEFAULT = "zhihan1996/DNABERT-2-117M"
# (model name, device) -> (tokenizer, model). getEmbedding is called once per
# run, so caching avoids re-instantiating the checkpoint on every call.
_dnabert2_cache: dict = {}


def _loadDnabert2(model_name: str, device: str):
    """Return the (tokenizer, eval-mode model) pair, loading it on first use."""
    key = (model_name, device)
    if key not in _dnabert2_cache:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        model = transformers.AutoModel.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device).eval()
        _dnabert2_cache[key] = (tokenizer, model)
    return _dnabert2_cache[key]


def getEmbedding(
    csv: pd.DataFrame, batch_size: int, model_name: str = _DNABERT2_DEFAULT
) -> np.ndarray:
    """
    Calculate the mean-pooled DNABERT-2 embedding of every sequence in the csv.

    Sequences are padded per batch. The tokenizer's attention mask is passed
    to the model and reused as pooling weights, so padding tokens neither take
    part in attention nor enter the mean. As a result a sequence's embedding
    does not depend (beyond floating-point noise) on which other sequences
    share its batch.

    Args:
        csv: pd.DataFrame with a "sequence" column (check README.md for the
            full format).
        batch_size: int, number of sequences tokenized and embedded at once.
        model_name: str, Hugging Face id of the DNABERT-2 checkpoint.

    Returns:
        embedding: np.ndarray of shape (len(csv), hidden_size), which is
            (len(csv), 768) for the default checkpoint. Rows follow csv row
            order. An empty csv yields an array with zero rows.

    NOTE: results differ from the earlier unmasked pooling whenever a batch
    needed padding. UMAP checkpoints or cached projections built from the old
    embeddings should be regenerated.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, dnabert2 = _loadDnabert2(model_name, device)

    sequences = csv["sequence"].to_list()
    batches = []  # collected per batch, concatenated once (avoids O(n^2) copying)
    for start in range(0, len(sequences), batch_size):
        tokens = tokenizer(
            sequences[start:start + batch_size], return_tensors="pt", padding=True
        )
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)
        with torch.no_grad():
            hidden = dnabert2(input_ids, attention_mask=attention_mask)[0]
            # masked mean over the token axis: padded positions get weight 0
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        batches.append(pooled.cpu().numpy())

    if not batches:
        # empty input: return a well-shaped empty array instead of None
        return np.empty((0, dnabert2.config.hidden_size), dtype=np.float32)
    return np.concatenate(batches, axis=0)

# DNABERT-2 embeddings of the filtered reads of each run, keyed by run id.
# Keys are pre-filled (value None) so they follow profile order.
embedding_dnabert2: dict[str, np.ndarray] = dict.fromkeys(profile["Run"])
pval_column = str(pval_thresh)
for run in tqdm.tqdm(profile["Run"], smoothing=0.0, unit="run"):
    csv_path = os.path.join(data_fold_cpu, "csv", f"{run}/{chr}.csv")
    csv = pd.read_csv(csv_path)
    # keep only rows whose value in the threshold-named column is at least 1
    csv = csv.loc[csv[pval_column] >= 1]
    embedding_dnabert2[run] = getEmbedding(csv, batch_size=100)

In [None]:
""" UMAP """

# Fit a 2-D cuML UMAP on the DNABERT-2 embeddings of all runs, or reload the
# reducer previously saved at {umap_fold}/{umap_ckpt}.sav.
# NOTE(review): the checkpoint path encodes only `umap_ckpt`, not `chr` or
# `pval_thresh`. An existing checkpoint is reused even if those changed.
umap_path = os.path.join(umap_fold, f"{umap_ckpt}.sav")
if os.path.exists(umap_path):
    # joblib.load unpickles the file: only load checkpoints you produced.
    reducer = joblib.load(umap_path)
else:
    # joblib.dump only opens the file and does not create missing folders.
    # Create the folder before the expensive fit so the fitted model is not
    # lost to a FileNotFoundError at save time.
    os.makedirs(umap_fold or ".", exist_ok=True)
    embedding = np.concatenate(
        [embedding_dnabert2[run] for run in profile["Run"]], axis=0
    )
    reducer = cuml.UMAP(n_components=2)
    # NOTE(review): no random_state is passed, so a refit is presumably not
    # deterministic. Consider fixing one for reproducibility.
    reducer.fit(embedding)
    joblib.dump(reducer, umap_path)

In [None]:
""" UMAP Embedding """

# Project every run's DNABERT-2 embedding with the fitted reducer, caching the
# coordinates as one comma-separated file per run.
# NOTE(review): cached files are reused whenever they exist. They are not
# invalidated if the reducer is refitted.
embedding_umap = {}
# round() instead of int(): -log10 of a float threshold can come out a hair
# below the integer (e.g. 4.999...), and int() would truncate to the wrong tag.
pval_tag = int(round(-np.log10(pval_thresh)))
embedding_umap_save_fold = os.path.join(
    data_fold_gpu, f"dnabert2_umap{umap_ckpt}/chr{chr}_pval{pval_tag}"
)
os.makedirs(embedding_umap_save_fold, exist_ok=True)
for run in tqdm.tqdm(profile["Run"], smoothing=0.0, unit="run"):
    embedding_umap_save_path = os.path.join(embedding_umap_save_fold, f"{run}.csv")
    if os.path.exists(embedding_umap_save_path):
        # ndmin=2 keeps a single-row file as shape (1, k), so downstream
        # column indexing ([:, 0]) still works.
        embedding_umap[run] = np.loadtxt(
            embedding_umap_save_path, delimiter=",", ndmin=2
        )
    else:
        embedding_umap[run] = reducer.transform(embedding_dnabert2[run])
        np.savetxt(embedding_umap_save_path, embedding_umap[run], delimiter=",")

In [None]:
""" hexbin map for all sample """

# Grid of hexbin density maps of the 2-D UMAP coordinates, one panel per run,
# taken in `profile` row order. Flat slot 11 (last column of the second row)
# is deliberately left empty.
n_rows, n_cols = 4, 6
blank_slots = {11}
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(30, 20), sharex=True, sharey=True, dpi=500
)
n_panels = n_rows * n_cols - len(blank_slots)
if len(profile) > n_panels:
    # the fixed grid cannot hold every run; say so instead of dropping silently
    warnings.warn(
        f"hexbin grid has {n_panels} panels but profile has {len(profile)} "
        "runs; the extra runs are not plotted"
    )

index = -1
for i, ax in enumerate(axs.flat):
    if i in blank_slots:
        ax.axis("off")
        continue
    index += 1
    if index >= len(profile):
        # fewer runs than panels: hide leftover axes instead of raising KeyError
        ax.axis("off")
        continue

    run_id = profile["Run"][index]
    xy = embedding_umap[run_id]
    # identical gridsize and colour range on every panel so panels are comparable
    ax.hexbin(
        xy[:, 0],
        xy[:, 1],
        gridsize=150,
        cmap="Reds",
        vmin=2, vmax=80,
    )

    # spine colour encodes treatment: red for 0 (pre), blue otherwise (post)
    spine_color = "r" if profile["Treatment"][index] == 0 else "b"
    for spine in ax.spines.values():
        spine.set_color(spine_color)
        spine.set_linewidth(2)

    ax.set_aspect("equal")
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_title(f"Run: {run_id}; Isolate: {profile['Isolate'][index]};")

fig.legend(
    handles=[
        matplotlib.patches.Patch(color="r", label="pre "),
        matplotlib.patches.Patch(color="b", label="post")
    ], loc="upper right", ncol=2, fontsize=14,
)
fig.suptitle(
    f"Stanford Data; chromosome {chr}; SNPs p-val threshold (<=) {pval_thresh}",
    fontweight='bold', y=0.992, fontsize=16
)
fig.supxlabel("UMAP1", fontweight='bold', y=0.005, fontsize=16)
fig.supylabel("UMAP2", fontweight='bold', x=0.010, fontsize=16)
fig.tight_layout()
fig.savefig(
    os.path.join(embedding_umap_save_fold, "hexbin.png"),
    dpi=500,
)
# close this specific figure (not whatever pyplot considers current) to free
# its large 500-dpi canvas
plt.close(fig)