In [None]:
# Root folder; the cells below read "profile.txt" and the
# "embd-{method}/{run}/{chromosome}.npy" embedding files from it.
data_fold = "data/stanford/"
# Embedding variant; selects the "embd-{method}" sub-folder.
method = "pretrain"     # pretrain or finetune
# Isolates whose runs are used to choose the selected positions.
isolates = [3, 7]
# Chromosome whose .npy embedding file is loaded for every run.
chromosome = "6"
# P-value threshold; the cells below turn it into a column index
# (769 - log10(pval_thresh)) and keep rows whose value there is >= 1.
pval_thresh = 1e-5
# Which end of the distance ranking to keep.
# NOTE(review): presumably the `part` argument of getSelectedPos -- confirm.
part = "top"            # top, low
# Number of reads kept per ranking.
# NOTE(review): presumably the `k_reads` argument of getSelectedPos -- confirm.
k_reads = 4000         # 2000, 4000, 12000
# Half-width of the position window handed to getSelectedEmbd.
pos_range = 64

In [None]:
import torch

import umap
import numpy as np
import pandas as pd
import sklearn.decomposition

import os
import seaborn as sns
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import matplotlib.patches

""" function """

def getSelectedPos(
    embd: dict[str, np.ndarray], part: str, k_reads: int
) -> np.ndarray:
    """Select read positions by their cross-run embedding distance.

    For every pair of runs the Euclidean distance between all of their read
    embeddings is computed, and each read is scored with its largest distance
    to any read of the other run. Per pair and per run the reads are ranked
    by that score, reads sharing a position are collapsed to the best-ranked
    one, and the first ``k_reads`` are kept. The positions of all kept reads
    (over every pair and every run) are merged.

    Args:
        embd: run name -> array with at least 769 columns; columns 0..767 are
            the embedding and column 768 is the read position.
        part: "top" ranks from the highest score down, "low" from the lowest
            score up.
        k_reads: number of (position-unique) reads kept per run and pair.

    Returns:
        1-D array of the selected positions, unique and sorted ascending.
        Empty when ``embd`` is empty or holds a single run (no pair exists).

    Raises:
        ValueError: if ``part`` is neither "top" nor "low".
    """
    if part not in ("top", "low"):
        raise ValueError(f"part must be 'top' or 'low', got {part!r}")

    runs = list(embd)
    if not runs:
        return np.empty(0)

    ascending = part == "low"
    # one boolean mask per run; a read is selected as soon as any pair picks it
    selected = {run: np.zeros(len(embd[run]), dtype=bool) for run in runs}

    def mark(run: str, score: np.ndarray) -> None:
        # DataFrame keeps the original row index through sort / de-duplication
        ranked = pd.DataFrame({"distance": score, "pos": embd[run][:, 768]})
        ranked = ranked.sort_values(by="distance", ascending=ascending)
        # several reads may share one position: keep the best-ranked one
        ranked = ranked.drop_duplicates(subset="pos", keep="first")
        selected[run][ranked.head(k_reads).index.to_numpy()] = True

    for i, run_i in enumerate(runs):
        vectors_i = torch.tensor(embd[run_i][:, :768])              # (Ni, 768)
        for run_j in runs[i + 1:]:
            vectors_j = torch.tensor(embd[run_j][:, :768])          # (Nj, 768)
            distance_matrix = torch.cdist(vectors_i, vectors_j).numpy()  # (Ni, Nj)
            mark(run_i, np.max(distance_matrix, axis=1))            # (Ni, )
            mark(run_j, np.max(distance_matrix, axis=0))            # (Nj, )

    # np.unique both removes duplicates and returns the values sorted
    return np.unique(
        np.concatenate([embd[run][selected[run], 768] for run in runs])
    )

def getSelectedEmbd(
    embd: np.ndarray, pos: np.ndarray, pos_range: int
) -> np.ndarray:
    """Pick, for every target position, the embedding of the nearest read.

    Both ``embd`` (by its position column) and ``pos`` must be sorted
    ascending: the read cursor only ever moves forward.

    Args:
        embd: array with at least 769 columns; columns 0..767 are the
            embedding and column 768 is the read position.
        pos: target positions.
        pos_range: a read is eligible for a target only if its position lies
            within ``[target - pos_range, target + pos_range]``.

    Returns:
        Array of shape ``(len(pos), 768)``. Row ``p`` is the embedding of the
        eligible read closest to ``pos[p]`` (on a tie the read with the larger
        position wins), or all zeros when no read is eligible.
    """
    embd_selected = np.zeros([len(pos), 768])   # [~13000, 768]
    n_reads = len(embd)
    e = 0   # cursor: first read not yet known to be left of the window
    for p, target in enumerate(pos):
        lower, upper = target - pos_range, target + pos_range
        # advance to the first read at or after the window's lower bound;
        # since pos is sorted, the next target can resume from this cursor
        while e < n_reads and embd[e, 768] < lower:
            e += 1
        # every read lies left of this window, and therefore also left of all
        # later (larger) windows: the remaining rows stay zero
        if e >= n_reads:
            break
        # first candidate is already right of the window: no read in range
        if embd[e, 768] > upper:
            continue
        # scan the reads inside the window and keep the closest one
        best = pos_range
        for k in range(e, n_reads):
            read_pos = embd[k, 768]
            if read_pos > upper:
                break
            gap = abs(read_pos - target)
            if gap <= best:
                best = gap
                embd_selected[p] = embd[k, :768]
    return embd_selected    # [~13000, 768]

def pearsonCorrelation(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Pearson correlation between every row of ``x1`` and every row of ``x2``.

    Both inputs are 2-D with the same number of columns. Entry ``[i, j]`` of
    the ``(N1, N2)`` result is the correlation of ``x1[i]`` with ``x2[j]``.
    """
    def center(rows: np.ndarray) -> np.ndarray:
        # subtract each row's own mean
        return rows - rows.mean(axis=1, keepdims=True)

    left, right = center(x1), center(x2)                    # (N1, D), (N2, D)
    left_ss = np.sum(left**2, axis=1, keepdims=True)        # (N1, 1)
    right_ss = np.sum(right**2, axis=1, keepdims=True)      # (N2, 1)
    cross = left @ right.T                                  # (N1, N2)
    scale = np.sqrt(left_ss @ right_ss.T)                   # (N1, N2)
    return cross / scale

""" profile """

# Load the run metadata table.
profile = pd.read_csv(os.path.join(data_fold, "profile.txt"))
# Isolate label (e.g. su001) -> integer (e.g. 1): drop the two-character prefix.
profile["Isolate"] = profile["Isolate"].map(lambda name: int(name[2:]))
# Treatment label -> 0 when it contains "pre", 1 otherwise.
profile["Treatment"] = profile["Treatment"].map(
    lambda label: 0 if "pre" in label else 1
)
# Order rows by Isolate and Treatment ascending and Tissue descending
# (normal before BCC), renumber them, and keep only the four columns that the
# later cells read.
profile = (
    profile
    .sort_values(
        by=["Isolate", "Treatment", "Tissue"], ascending=[True, True, False]
    )
    .reset_index(drop=True)
    .loc[:, ["Run", "Isolate", "Treatment", "Tissue"]]
)

In [None]:
import selector
import importlib
# re-import so edits to selector.py are picked up without restarting the kernel
importlib.reload(selector)

# use the notebook-level pval_thresh setting instead of repeating the literal,
# so the selector and the embedding filter cannot drift apart
feature = selector.Selector("data/feature", pval_thresh=pval_thresh, ascending=False)

In [None]:
""" plot correlation heat map of runs for selected pos"""

# runs used to choose the selected positions
runs = []
for isolate in isolates:
    runs += profile[profile["Isolate"] == isolate]["Run"].to_list()

# column whose value must be >= 1 for a read to be kept; derived from the
# p-value threshold (loop-invariant, so computed once)
pval_col = 769-int(np.log10(pval_thresh))

# step 1: load every run and keep only the reads passing the p-value filter
embd = {}
for run in runs:
    loaded = np.load(os.path.join(data_fold, f"embd-{method}/{run}/{chromosome}.npy"))
    # [~40k, 776]
    embd[run] = loaded[loaded[:, pval_col] >= 1, :]

# step 2: choose the positions from these runs. `pos` is read by
# getSelectedEmbd below (and by later cells) but was never assigned anywhere,
# while getSelectedPos / part / k_reads were defined and never used.
pos = getSelectedPos(embd, part, k_reads)

# step 3: one embedding per selected position, len(runs) * [~13000, 768]
for run in runs:
    embd[run] = getSelectedEmbd(embd[run], pos, pos_range)

# pca of embd, use pca1 to represent each run
# len(runs) * [~13000, 768] -> [len(runs), ~13000]
pca_reducer = sklearn.decomposition.PCA(n_components=1)
pca_reducer.fit(np.concatenate([embd[run][:, :768] for run in runs], axis=0))
pca = np.array([pca_reducer.transform(embd[run][:, :768]) for run in runs])[:, :, 0]

# calculate the correlation between each pair of runs
heat_map = pearsonCorrelation(pca, pca)     # [len(runs), len(runs)]
# blank the diagonal: a run's correlation with itself carries no information
np.fill_diagonal(heat_map, np.nan)

# plot the heat map, one row / column per run
labels = []
for run in runs:
    # `run_id` rather than `id`, which would shadow the builtin
    run_id, isolate, treatment, tissue = profile[profile["Run"] == run].iloc[0]
    labels.append("{}-{}-{}-{}".format(
            run_id, isolate,
            "Pre" if treatment == 0 else "Post", 
            "Normal" if "normal" in tissue else "Tumor"
        ))
plt.figure(dpi=300)
sns.heatmap(
    heat_map, annot=True, fmt=".3f", cmap="coolwarm", square=True,
    yticklabels=labels, xticklabels=False
)
plt.title("Pearson Correlation", fontweight='bold')
plt.show()

In [None]:
""" plot scatter of PCA1 and PCA2 of all runs """

# runs we will plot
runs = profile["Run"].to_list()

# column whose value must be >= 1 for a read to be kept; derived from the
# p-value threshold (loop-invariant, so computed once)
pval_col = 769-int(np.log10(pval_thresh))

# embd after snps and pos filter, len(runs) * [~13000, 768]
# NOTE: `pos` is not assigned in this cell; it must already exist in the
# notebook session (positions as produced by getSelectedPos).
embd = {}
for run in runs:
    loaded = np.load(os.path.join(data_fold, f"embd-{method}/{run}/{chromosome}.npy"))
    # [~40k, 776]
    loaded = loaded[loaded[:, pval_col] >= 1, :]
    # [~13000, 768]
    embd[run] = getSelectedEmbd(loaded, pos, pos_range)

# pca of embd, use pca1 to represent each run
# len(runs) * [~13000, 768] -> [len(runs), ~13000]
# five components are fitted so their explained variance can be printed, but
# only the first one (index 0) is kept per run
pca_reducer = sklearn.decomposition.PCA(n_components=5)
pca_reducer.fit(np.concatenate([embd[run][:, :768] for run in runs], axis=0))
pca = np.array([pca_reducer.transform(embd[run][:, :768]) for run in runs])[:, :, 0]

print(pca_reducer.explained_variance_ratio_)

# [len(runs), ~13000] -> [len(runs), 2], for plot
pca = sklearn.decomposition.PCA(n_components=2).fit_transform(pca)

# look every run up once (first matching profile row) instead of re-scanning
# the table for each colour / label
run_info = {run: profile[profile["Run"] == run].iloc[0] for run in runs}

# point labels: "<run>-<isolate>"
labels = [
    "{}-{}".format(run_info[run]["Run"], run_info[run]["Isolate"]) for run in runs
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True, dpi=300)

# one (point colours, legend handles) pair per panel
isolate_colors = [f"C{i}" for i in range(8)]
panels = [
    # color by isolate
    (
        [isolate_colors[run_info[run]["Isolate"]-1] for run in runs],
        [
            matplotlib.patches.Patch(color=isolate_colors[i], label=f"Isolate {i+1}")
            for i in range(8)
        ],
    ),
    # color by treatment
    (
        ["red" if run_info[run]["Treatment"] == 0 else "blue" for run in runs],
        [
            matplotlib.patches.Patch(color="r", label="pre "),
            matplotlib.patches.Patch(color="b", label="post"),
        ],
    ),
    # color by tissue
    (
        ["green" if "normal" in run_info[run]["Tissue"] else "purple" for run in runs],
        [
            matplotlib.patches.Patch(color="g", label="normal"),
            matplotlib.patches.Patch(color="purple", label="BCC"),
        ],
    ),
]
# the three panels show the same points and labels; only colours and legend differ
for ax, (point_colors, handles) in zip(axs, panels):
    ax.scatter(pca[:, 0], pca[:, 1], c=point_colors)
    for i, label in enumerate(labels):
        ax.text(pca[i, 0], pca[i, 1], label, fontsize=3)
    ax.legend(handles=handles, loc="upper right", ncol=1)

fig.suptitle(f"Selected Pos from isolates {isolates}", fontweight='bold')
fig.supxlabel("PCA1", fontweight='bold')
fig.supylabel("PCA2", fontweight='bold')
fig.tight_layout()
fig.show()

In [None]:
""" plot Distribution of Euclidean Distance """

bins = 100
fontsize = 14

# runs used to choose the selected positions
runs = []
for isolate in isolates:
    runs += profile[profile["Isolate"] == isolate]["Run"].to_list()

# column whose value must be >= 1 for a read to be kept; derived from the
# p-value threshold (loop-invariant, so computed once)
pval_col = 769-int(np.log10(pval_thresh))

temp_embd = {}
for run in runs:
    temp_embd[run] = np.load(os.path.join(data_fold, f"embd-{method}/{run}/{chromosome}.npy"))
    temp_embd[run] = temp_embd[run][temp_embd[run][:, pval_col]>=1, :]

# the two runs being compared; panel i shows the reads of compared_runs[i]
compared_runs = (runs[0], runs[3])
x1embd = temp_embd[compared_runs[0]][:, :768]     # (N1, 768)
x2embd = temp_embd[compared_runs[1]][:, :768]     # (N2, 768)
distance_matrix = torch.cdist(      # (N1, N2)
    torch.tensor(x1embd), torch.tensor(x2embd)
).numpy()
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
for i, run in enumerate(compared_runs):
    # panel 0: per-read maximum over axis 1 (reads of the first run);
    # panel 1: per-read maximum over axis 0 (reads of the second run)
    axs[i].hist(np.max(distance_matrix, axis=1 - i), bins=bins)
    # title: describe the run whose reads are histogrammed in this panel
    # (previously the second panel was titled with runs[1] while showing runs[3])
    run_id, isolate, treatment, tissue = profile[profile["Run"] == run].iloc[0]
    axs[i].set_title(
        "{}-{}-{}-{}".format(
            run_id, isolate, 
            "Pre" if treatment == 0 else "Post", 
            "Normal" if "normal" in tissue else "Tumor"
        ), 
        fontweight='bold', fontsize=fontsize
    )
    # axis
    axs[i].set_xlim(5, 11)
fig.supxlabel("Euclidean Distance", fontweight='bold', y=0.004, fontsize=fontsize)
fig.supylabel("Number of Reads", fontweight='bold', x=0.008, fontsize=fontsize)
fig.tight_layout()
fig.show()

In [None]:
""" Plot Hexbin N * N+1"""

"""
gridsize = 80
vmin, vmax = 4, 64
fontsize = 18

fig, axs = plt.subplots(
    len(runs), len(runs)+1, figsize=(35, 30), 
    sharex=True, sharey=True
)
# plot
for i in range(len(runs)):
    # plot umpa of sample i before filter in diagonal
    axs[i, i].hexbin(
        umap[runs[i]][:, 0], umap[runs[i]][:, 1], 
        cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax,
    )
    # plot umap of (i, :) and (:, i) after filter by distance (i, j)
    for j in range(i+1, len(runs)):
        # umap of (i, :)
        axs[i, j].hexbin(
            umap[runs[i]][:, 0], umap[runs[i]][:, 1], 
            cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax/2,
        )
        # umap of (:, i)
        axs[j, i].hexbin(
            umap[runs[j]][:, 0], umap[runs[j]][:, 1], 
            cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax/2,
        )
    # plot umap of (i, len(runs)) of combine all filter
    axs[i, len(runs)].hexbin(
        umap[runs[i]][:, 0], 
        umap[runs[i]][:, 1], 
        cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax/2,
    )
# set up axis and label
for i in range(len(runs)):
    id, isolate, treatment, tissue = profile[profile["Run"] == runs[i]].iloc[0]
    treatment = "Pre" if treatment == 0 else "Post"
    tissue = "Normal" if "normal" in tissue else "Tumor"
    info = f"{id} - {isolate} - {treatment} - {tissue}"
    axs[0,  i].set_xlabel(info, fontsize=fontsize, fontweight='bold')
    axs[0,  i].xaxis.set_label_position("top")
    axs[0, len(runs)].set_xlabel("All Selected Reads", fontsize=fontsize, fontweight='bold')
    axs[0, len(runs)].xaxis.set_label_position("top")
    axs[i, -1].set_ylabel(info, fontsize=fontsize, fontweight='bold')
    axs[i, -1].yaxis.set_label_position("right")
    for j in range(len(runs)+1):
        axs[i, j].set_aspect("equal")
        axs[i, j].set_xlim(-22, 22)
        axs[i, j].set_ylim(-22, 22)
fig.supxlabel("UMAP1", fontsize=fontsize, fontweight='bold', y=0.004)
fig.supylabel("UMAP2", fontsize=fontsize, fontweight='bold', x=0.008)
fig.tight_layout()
fig.savefig("temp.png", dpi=500)
fig.show()
"""

In [None]:
""" Plot Hexbin of all samples """

"""
gridsize = 80
vmin, vmax = 4, 64
fontsize = 18

# hexbin, split by isolate
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
umap_isolates = [None, None]
for i in range(len(runs)):
    axs_index = 0 if profile[profile["Run"] == runs[i]]["Isolate"].values[0] == 3 else 1
    if umap_isolates[axs_index] is None:
        umap_isolates[axs_index] = umap[runs[i]]
    else:
        umap_isolates[axs_index] = np.concatenate(
            [umap_isolates[axs_index], umap[runs[i]]]
        )
for axs_index in range(2):
    axs[axs_index].hexbin(
        umap_isolates[axs_index][:, 0], umap_isolates[axs_index][:, 1], 
        cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax,
    )
    axs[axs_index].set_xlim(-22, 22)
    axs[axs_index].set_ylim(-22, 22)
    axs[axs_index].set_title(
        "su003" if axs_index == 0 else "su007", 
        fontweight='bold', fontsize=fontsize
    )
fig.tight_layout()
fig.show()

# hexbin, split by treatment
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
umap_treatments = [None, None]
for i in range(len(runs)):
    axs_index = profile[profile["Run"] == runs[i]]["Treatment"].values[0]
    if umap_treatments[axs_index] is None:
        umap_treatments[axs_index] = umap[runs[i]]
    else:
        umap_treatments[axs_index] = np.concatenate(
            [umap_treatments[axs_index], umap[runs[i]]]
        )
for axs_index in range(2):
    axs[axs_index].hexbin(
        umap_treatments[axs_index][:, 0], umap_treatments[axs_index][:, 1], 
        cmap="Reds", gridsize=gridsize,
        vmin = vmin if axs_index == 0 else (vmin/2),
        vmax = vmax if axs_index == 0 else (vmax/2)
    )
    axs[axs_index].set_xlim(-22, 22)
    axs[axs_index].set_ylim(-22, 22)
    axs[axs_index].set_title(
        "Pre" if axs_index == 0 else "Post", 
        fontweight='bold', fontsize=fontsize
    )
fig.tight_layout()
fig.show()
"""