In [None]:
import torch
import transformers
import cuml

import cupy as cp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches

import os
import warnings
import tqdm.notebook as tqdm

transformers.logging.set_verbosity_error()
warnings.filterwarnings("ignore", message="Unable to import Triton*")

# parameter
# NOTE: `chr` shadows the builtin chr(); every later cell reads this name, so
# renaming it means updating all of them together.
chr = "1"            # chromosome whose per-run csv files are processed / plotted
pval_thresh = 1e-4   # str(pval_thresh) is the csv column used to select reads
batch_size = 100     # sequences per DNABERT-2 forward pass

# path: the same project tree on two mounts, each with a "csv" sub-folder
data_fold_cpu = "/mnt/efs_v2/tag_onc/users/tianrui.qi/TCGA-Onc/data/public/"
data_fold_gpu = "/mnt/efs_v2/dbgap_tcga/users/tianrui.qi/TCGA-Onc/data/public/"
csv_load_fold_cpu, csv_load_fold_gpu = (
    os.path.join(fold, "csv") for fold in (data_fold_cpu, data_fold_gpu)
)

# dataset: sample sheet; later cells use its Run, Isolate and Treatment columns
profile = pd.read_csv(os.path.join(data_fold_cpu, "profile.txt"))
# Isolate: "su001" -> 1 (drop the two-letter prefix, parse the number)
profile["Isolate"] = profile["Isolate"].map(lambda name: int(name[2:]))
# Treatment: any label containing "pre" -> 0, everything else (post) -> 1
profile["Treatment"] = profile["Treatment"].map(
    lambda label: 0 if "pre" in label else 1
)

In [None]:
# model for embedding
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "zhihan1996/DNABERT-2-117M", trust_remote_code=True
)
dnabert2 = transformers.AutoModel.from_pretrained(
    "zhihan1996/DNABERT-2-117M", trust_remote_code=True
).to(device).eval()

chr_list = [str(i) for i in range(1, 23)] + ["X"]   # BAM naming convention
# NOTE(review): chr_list is not used in this cell; kept in case other cells rely on it.


def _embed_sequences(sequences):
    """Return one mean-pooled DNABERT-2 embedding per sequence (float32 numpy rows).

    The tokenizer's attention mask is passed to the model and used as pooling
    weights, so padding tokens added to shorter sequences do not enter the mean.
    Assumes output[0] of the model is (batch, tokens, hidden) — TODO confirm
    against the DNABERT-2 remote code if the model revision changes.
    """
    tokens = tokenizer(sequences, return_tensors="pt", padding=True)
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)
    with torch.no_grad():
        hidden = dnabert2(input_ids, attention_mask=attention_mask)[0]
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # clamp guards against a division by zero for an all-padding row
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
    return pooled.cpu().numpy()


def _serialize_embedding(vector):
    """Encode a 1-D embedding as the text "[v0 v1 ...]" for storage in the csv.

    Written explicitly (one line, Python float repr of every value) so the cached
    csv holds a parseable, round-trippable string instead of whatever pandas
    produces when it stringifies an ndarray cell.
    """
    return "[" + " ".join(map(str, vector.tolist())) + "]"


# NOTE: rows cached by earlier versions of this cell were mean-pooled without the
# attention mask; delete the gpu-fold csv files to recompute them consistently.
for run in tqdm.tqdm(profile["Run"], smoothing=0.0, unit="run"):
    gpu_path = os.path.join(csv_load_fold_gpu, f"{run}/{chr}.csv")
    cpu_path = os.path.join(csv_load_fold_cpu, f"{run}/{chr}.csv")

    # load the csv, try gpu (our cache) first then cpu
    from_gpu = os.path.exists(gpu_path)
    csv = pd.read_csv(gpu_path if from_gpu else cpu_path)
    if "embedding" not in csv.columns:
        csv["embedding"] = None
    # an all-empty column is read back as float64; object dtype lets strings in
    csv["embedding"] = csv["embedding"].astype(object)

    # reads that pass pval_thresh and do not have an embedding yet
    pending = csv.index[(csv[str(pval_thresh)] >= 1) & csv["embedding"].isna()]

    # skip only when nothing is missing AND the cached copy already exists; a csv
    # that came from the cpu fold is still written to the gpu fold below, because
    # the loading cell reads from the gpu fold only
    if len(pending) == 0 and from_gpu:
        continue

    # calculate the embedding batch by batch
    for start in tqdm.tqdm(
        range(0, len(pending), batch_size),
        smoothing=0.0, desc=run, unit="batch", leave=False
    ):
        index_batch = pending[start:start + batch_size]
        sequence_batch = csv.loc[index_batch, "sequence"].to_list()
        embedding_batch = _embed_sequences(sequence_batch)
        for row, vector in zip(index_batch, embedding_batch):
            csv.at[row, "embedding"] = _serialize_embedding(vector)

    # save to gpu fold so that we can skip the calculation next time
    os.makedirs(os.path.dirname(gpu_path), exist_ok=True)
    csv.to_csv(gpu_path, index=False)

In [None]:
def _parse_embedding(value):
    """Decode one cached embedding cell into a 1-D float32 numpy array.

    csv round-trips turn the stored vectors into text such as "[v0 v1 ...]"
    (possibly wrapped over several lines); array-likes are accepted as-is.
    A summarised print ("...") fails loudly in float() rather than silently.
    """
    if isinstance(value, str):
        tokens = value.strip().strip("[]").split()
        return np.array([float(tok) for tok in tokens], dtype=np.float32)
    return np.asarray(value, dtype=np.float32)


# embedding_dict[run] -> (n_passing_reads, dim) float32 matrix, rows in csv order
embedding_dict = {}
for run in tqdm.tqdm(profile["Run"], smoothing=0.0, unit="run"):
    csv = pd.read_csv(os.path.join(csv_load_fold_gpu, f"{run}/{chr}.csv"))
    passing = csv.loc[csv[str(pval_thresh)] >= 1, "embedding"]
    n_missing = int(passing.isna().sum())
    if n_missing:
        raise ValueError(
            f"{run}: {n_missing} reads passing pval={pval_thresh} have no "
            "embedding; rerun the embedding cell"
        )
    vectors = [_parse_embedding(value) for value in passing]
    # runs with no passing reads get an empty placeholder; downstream cells skip them
    embedding_dict[run] = (
        np.stack(vectors) if vectors else np.empty((0, 0), dtype=np.float32)
    )

In [None]:
# stack every run's read embeddings in profile order (runs without passing reads
# are skipped; they contribute no labels downstream either) and project to 2-D
embedding_host = np.concatenate(
    [embedding_dict[run] for run in profile["Run"] if len(embedding_dict[run])],
    axis=0,
).astype(np.float32)
# fixed random_state so reruns of the notebook reproduce the same layout
embedding = cp.asnumpy(
    cuml.UMAP(n_components=2, verbose=True, random_state=0).fit_transform(
        cp.asarray(embedding_host)
    )
)

In [None]:
# per-read treatment label (0 = pre, 1 = post), repeated once per read of each run
# in the same profile order used to build `embedding`
treatment_label = np.concatenate([
    np.full(len(embedding_dict[run]), treatment, dtype=float)
    for run, treatment in zip(profile["Run"], profile["Treatment"])
], axis=0).reshape(-1, 1)
assert len(treatment_label) == len(embedding), "label count != embedding rows"

# legend handles: red = pre, blue = post
# (the variable names are historical: b_patch is the red/"pre" handle)
b_patch = matplotlib.patches.Patch(color="r", label="pre ")
r_patch = matplotlib.patches.Patch(color="b", label="post")

# concatenate embedding and label so one permutation shuffles both; shuffling keeps
# one group from being drawn entirely on top of the other. Seeded for reruns.
x = np.concatenate([embedding, treatment_label], axis=1)
x = x[np.random.default_rng(0).permutation(len(x)), :]

# column 2 holds the label: 0 -> "r", anything else (1) -> "b"
color = np.where(x[:, 2] == 0, "r", "b")

In [None]:
def _plot_umap(points, colors, handles):
    """Scatter the first two columns of `points` as UMAP1/UMAP2 and show the figure.

    points:  array whose columns 0 and 1 are the UMAP coordinates.
    colors:  one matplotlib color for all points, or one color per row.
    handles: legend patches describing exactly what is drawn.
    """
    fig, ax = plt.subplots(dpi=1000, figsize=(8, 8))
    # tiny markers at high dpi so dense regions stay readable
    ax.scatter(points[:, 0], points[:, 1], c=colors, s=0.0001, marker="o")
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    ax.set_title(f"chromosome {chr}; pval={pval_thresh}")
    ax.legend(handles=handles)
    plt.show()


# both groups together, in shuffled draw order
_plot_umap(x, color, [b_patch, r_patch])
# red group only (b_patch is the red handle)
_plot_umap(x[color == "r"], "r", [b_patch])
# blue group only (r_patch is the blue handle)
_plot_umap(x[color == "b"], "b", [r_patch])