In [None]:
import torch
import transformers

import cuml
import umap
import cupy as cp
import numpy as np
import pandas as pd

import os
import joblib
import warnings
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import matplotlib.patches

# keep the notebook output quiet: hide warnings whose message matches the
# Triton import notice and let transformers log errors only
warnings.filterwarnings("ignore", message="Unable to import Triton*")
transformers.logging.set_verbosity_error()

# parameter
# NOTE: `chr` shadows the built-in chr(); the name is kept because later cells
# read it when building file paths and figure titles.
pval_thresh = 1e-5
chr = "6"
# path
data_fold = "/mnt/efs_v2/tag_onc/users/tianrui.qi/TCGA-Onc/data/public/"

# dataset
profile = pd.read_csv(os.path.join(data_fold, "profile.txt"))
profile = profile.assign(
    # Isolate string (i.e. su001) -> int (i.e. 1)
    Isolate=profile["Isolate"].map(lambda isolate: int(isolate[2:])),
    # Treatment pre/post -> 0/1 (anything not containing "pre" becomes 1)
    Treatment=profile["Treatment"].map(
        lambda treatment: int("pre" not in treatment)
    ),
)

# model for embedding: tokenizer and encoder come from the same checkpoint;
# the encoder is moved to the GPU when one is available and put in eval mode
model_name = "zhihan1996/DNABERT-2-117M"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True
)
dnabert2 = transformers.AutoModel.from_pretrained(
    model_name, trust_remote_code=True
)
dnabert2 = dnabert2.to(device).eval()

In [None]:
def getEmbedding(csv: pd.DataFrame, batch_size: int) -> np.ndarray:
    """
    Calculate embedding of all sequences in the csv.

    Each sequence is tokenized, passed through the encoder, and its token
    vectors are averaged into one vector. Because the tokenizer pads every
    batch to the longest sequence in it, the attention mask is forwarded to
    the model and used to weight the average, so padded positions do not
    contribute and a sequence's embedding does not depend on which other
    sequences share its batch.

    Args:
        csv: pd.DataFrame, check README.md for the format. Only the
            "sequence" column is read here.
        batch_size: int, batch size for the calculation. Must be >= 1.

    Returns:
        embedding: np.ndarray, the embedding of all sequences in the csv with
            shape (len(csv), 768). An empty csv gives an array of shape
            (0, 768) rather than None, so callers can always use len() and
            np.concatenate on the result.

    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # embedding width promised in the Returns section above; only needed to
    # shape the empty result
    hidden_size = 768
    if len(csv) == 0:
        return np.empty((0, hidden_size), dtype=np.float32)

    sequence_list = csv["sequence"].to_list()
    # collect per-batch results and concatenate once at the end; growing the
    # array inside the loop re-copies everything on every iteration
    embedding_list = []
    for start in range(0, len(sequence_list), batch_size):
        # sequence
        sequence_batch = sequence_list[start:start + batch_size]
        # token
        encoded = tokenizer(
            sequence_batch, return_tensors="pt", padding=True
        )
        token_batch = encoded["input_ids"].to(device)
        mask_batch = encoded["attention_mask"].to(device)
        # embedding
        with torch.no_grad():
            hidden = dnabert2(token_batch, attention_mask=mask_batch)[0]
            # masked mean over the token axis: weight is 1 for real tokens
            # and 0 for padding
            weight = mask_batch.unsqueeze(-1).to(hidden.dtype)
            token_count = weight.sum(dim=1).clamp(min=1)
            embedding_batch = (hidden * weight).sum(dim=1) / token_count
        # save
        embedding_list.append(embedding_batch.cpu().numpy())
    return np.concatenate(embedding_list, axis=0)

# one entry per run, filled in below; values start as None
embedding_dict: dict[str, np.ndarray] = dict.fromkeys(profile["Run"])
# name of the csv column that is compared against 1 to keep a read
pval_column = str(pval_thresh)
for run in tqdm.tqdm(profile["Run"], smoothing=0.0, unit="run"):
    # load this run's reads for the selected chromosome
    csv_path = os.path.join(data_fold, "csv", f"{run}/{chr}.csv")
    csv_raw = pd.read_csv(csv_path)
    # keep only the reads that pass the pval_thresh
    csv = csv_raw[csv_raw[pval_column] >= 1]
    # calculate the embedding by batch
    embedding_dict[run] = getEmbedding(csv, batch_size=int(1e2))

In [None]:
# Stack the per-run embeddings in profile["Run"] order into one training
# matrix for the reducer.
embedding_train = np.concatenate(
    [embedding_dict[run] for run in profile["Run"]], axis=0
)
reducer = cuml.UMAP(n_components=2)
# fit() returns the fitted estimator itself, so assigning its result to
# `embedding` would leave the model object under that name. fit_transform()
# fits the same model and returns the 2-column projection, one row per input
# row.
embedding = reducer.fit_transform(embedding_train)
# persist the fitted reducer so it can be reloaded without refitting
joblib.dump(reducer, "umap.sav")

In [None]:
# Reload the fitted reducer written by the fitting cell.
# NOTE: joblib.load unpickles the file; only load "umap.sav" files produced by
# this notebook.
reducer = joblib.load("umap.sav")
# Project the reads of every run, stacked in profile["Run"] order. The
# colouring cell builds one label per read over all runs in that same order
# and concatenates it column-wise with `embedding`, so transforming only the
# first run would give a row-count mismatch.
embedding = reducer.transform(
    np.concatenate([embedding_dict[run] for run in profile["Run"]], axis=0)
)

In [None]:
# one 0/1 label per embedded read: every read inherits the Treatment value of
# the run it belongs to, in profile order
run_size = [len(embedding_dict[run]) for run in profile["Run"]]
color = np.repeat(
    profile["Treatment"].to_numpy() % 2, run_size
).astype(float).reshape(-1, 1)

# legend handles: red is labelled "pre ", blue is labelled "post".
# NOTE: the variable names are swapped relative to the colors they carry; they
# are kept because the plotting cell refers to them by these names.
b_patch = matplotlib.patches.Patch(label="pre ", color="r")
r_patch = matplotlib.patches.Patch(label="post", color="b")

# attach the label as a third column so that points and labels are shuffled
# together
x = np.concatenate([embedding, color], axis=1)
shuffle_index = np.random.permutation(len(x))
x = x[shuffle_index]

# label 0 -> "r", label 1 -> "b"
color = np.where(x[:, 2] == 0, "r", "b")

In [None]:
dpi = 500
figsize = (8, 8)
# marker area and axis range shared by all three figures
point_size = 0.0001
axis_limit = [-25, 25]


def _save_scatter(points, point_color, filename):
    """
    Draw one UMAP scatter figure and save it as an image.

    All figures share the same size, axis range, axis labels, title and
    legend, so they can be compared side by side; only the plotted points,
    their color and the output file differ.

    Args:
        points: array whose first two columns are the UMAP coordinates.
        point_color: a single matplotlib color, or one color per row of
            `points`.
        filename: path of the image file to write.
    """
    fig, ax = plt.subplots(dpi=dpi, figsize=figsize)
    ax.scatter(
        points[:, 0], points[:, 1], c=point_color, s=point_size, marker="o"
    )
    ax.set_xlim(axis_limit)
    ax.set_ylim(axis_limit)
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    ax.set_title(f"chromosome {chr}; pval={pval_thresh}")
    ax.legend(handles=[b_patch, r_patch])
    fig.savefig(filename, dpi=dpi)


# all reads, colored by treatment
_save_scatter(x, color, "PrePost.png")
# reads colored "r" only
_save_scatter(x[color == "r"], "r", "Pre.png")
# reads colored "b" only
_save_scatter(x[color == "b"], "b", "Post.png")