In [2]:
# Embedding / UMAP variant to load: "pretrain" or "finetune".
method = "pretrain"
# Isolate numbers whose runs are analysed.
isolates = [3, 7]
# Chromosome whose per-run .npy files are read.
chromosome = "6"
# P-value threshold; its log10 picks the column used to filter reads.
pval_thresh = 1e-5
# Which end of the distance ranking to keep: "top" or "low".
part = "top"
# Rank cut-off used when selecting reads per pair of runs (e.g. 1000 or 4000).
k_reads = 4000

In [7]:
import torch

import numpy as np
import pandas as pd
import scipy.stats
import sklearn.decomposition

import os
import seaborn as sns
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import matplotlib.patches

""" Profile """

# Sample sheet: one row per sequencing run.
profile = pd.read_csv("data/stanford/profile.txt")
profile = profile.assign(
    # Isolate label -> integer, e.g. "su001" -> 1 (skip the two leading characters).
    Isolate=profile["Isolate"].map(lambda label: int(label[2:])),
    # Treatment label -> 0 when it contains "pre", otherwise 1.
    Treatment=profile["Treatment"].map(lambda label: 0 if "pre" in label else 1),
)
# Order rows by Isolate (ascending), Treatment (ascending) and Tissue
# (descending), renumber the index, then keep only the four columns used later.
profile = (
    profile
    .sort_values(by=["Isolate", "Treatment", "Tissue"], ascending=[True, True, False])
    .reset_index(drop=True)
)[["Run", "Isolate", "Treatment", "Tissue"]]

""" Functions """

def euclideanDistance(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Pairwise Euclidean distance between the rows of two 2-D arrays.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the
    full (N1, N2) result comes from a single matrix product.

    Args:
        x1: array of shape (N1, D).
        x2: array of shape (N2, D).

    Returns:
        Array of shape (N1, N2); entry (i, j) is the distance between
        x1[i] and x2[j].
    """
    row_sq = (x1**2).sum(axis=1, keepdims=True)         # (N1,  1)
    col_sq = (x2**2).sum(axis=1, keepdims=True).T       # ( 1, N2)
    cross = 2 * x1.dot(x2.T)                            # (N1, N2)
    squared = row_sq + col_sq - cross                   # (N1, N2)
    # Rounding can push tiny squared distances below zero; clamp before sqrt.
    return np.sqrt(np.maximum(0, squared))

def pearsonCorrelation(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Pairwise Pearson correlation between the rows of two 2-D arrays.

    Each row is centred on its own mean; the correlation is the dot product
    of the centred rows divided by the product of their root sums of squares.

    Args:
        x1: array of shape (N1, D).
        x2: array of shape (N2, D).

    Returns:
        Array of shape (N1, N2); entry (i, j) is the correlation between
        x1[i] and x2[j]. A constant row has zero variance, so its entries
        are the result of a division by zero.
    """
    dev1 = x1 - np.mean(x1, axis=1, keepdims=True)      # (N1, D)
    dev2 = x2 - np.mean(x2, axis=1, keepdims=True)      # (N2, D)
    ss1 = (dev1**2).sum(axis=1, keepdims=True)          # (N1, 1)
    ss2 = (dev2**2).sum(axis=1, keepdims=True)          # (N2, 1)
    scale = np.sqrt(ss1.dot(ss2.T))                     # (N1, N2)
    covariance = dev1.dot(dev2.T)                       # (N1, N2)
    return covariance / scale

In [4]:
""" embd, umap, filter """

# Reject a bad configuration up front: an unknown `part` would otherwise leave
# every off-diagonal mask all-False, and a non-positive `k_reads` gives a
# meaningless rank threshold.
if part not in ("top", "low"):
    raise ValueError(f"part must be 'top' or 'low', got {part!r}")
if k_reads < 1:
    raise ValueError(f"k_reads must be a positive integer, got {k_reads!r}")


def _selectReads(scores: np.ndarray, k: int, part: str) -> np.ndarray:
    """Boolean mask over `scores` keeping one end of the ranking.

    Args:
        scores: 1-D array with one score per read.
        k: rank cut-off (>= 1).
        part: "top" keeps scores >= the k-th largest; anything else ("low")
            keeps scores strictly below the value at sorted index k.

    Returns:
        Boolean array of the same length as `scores`. Ties at the threshold
        mean "top" may keep more than k reads and "low" may keep fewer.
        When there are at most k scores every read is kept, instead of
        indexing past the end of the sorted array.
    """
    n_scores = len(scores)
    if n_scores <= k:
        return np.full(n_scores, True, dtype=bool)
    ordered = np.sort(scores)
    if part == "top":
        return scores >= ordered[-k]
    return scores < ordered[k]


# Runs belonging to the configured isolates, in profile order.
runs = []
for isolate in isolates:
    runs += profile[profile["Isolate"] == isolate]["Run"].to_list()

# Offset of the filter column derived from the threshold (5 for 1e-5).
pval_offset = -int(np.log10(pval_thresh))

embd, umap = {}, {}
for run in runs:
    # load the embedding
    embd[run] = np.load(f"data/stanford/embd-{method}/{run}/{chromosome}.npy")
    # filter by p-value: keep rows whose value in column 769 + offset is >= 1
    # NOTE(review): presumably that column counts covered variants at this
    # threshold - confirm against the code that wrote the file
    embd[run] = embd[run][embd[run][:, 769 + pval_offset] >= 1, :]
    # only keep the first 768 columns, embedding
    embd[run] = embd[run][:, :768]
    # read umap of given sample id and chromosome
    umap[run] = np.load(f"data/stanford/umap-{method}/{run}/{chromosome}.npy")
    # same filter on the umap file, whose matching column is 3 + offset
    umap[run] = umap[run][umap[run][:, 3 + pval_offset] >= 1, :]
    # only keep the first 2 columns as umap coordinates
    umap[run] = umap[run][:, :2]

# filter[i][j] is a boolean mask over the reads of runs[i], selected from its
# distances to runs[j]. The name shadows the builtin `filter`, but the plotting
# cells below read it under this name, so it is kept.
filter = [[
    np.full(len(embd[runs[i]]), False, dtype=bool) for _ in range(len(runs))
] for i in range(len(runs))]
for i in range(len(runs)):
    filter[i][i] = np.full(len(embd[runs[i]]), True, dtype=bool)
    for j in range(i+1, len(runs)):
        x1embd, x2embd = embd[runs[i]], embd[runs[j]]
        if len(x1embd) == 0 or len(x2embd) == 0:
            # No reads on one side: max over an empty axis would raise, so
            # leave both masks all-False.
            continue
        distance_matrix = euclideanDistance(x1embd, x2embd)    # (N1, N2)
        # Score of a read = its largest distance to any read of the other run.
        x1score = np.max(distance_matrix, axis=1)               # (N1, )
        x2score = np.max(distance_matrix, axis=0)               # (N2, )
        filter[i][j] = _selectReads(x1score, k_reads, part)
        filter[j][i] = _selectReads(x2score, k_reads, part)

In [None]:
""" Plot Distribution of Euclidean Distance """

bins = 100
fontsize = 18

# Pairwise distances between the reads of the first two runs.
x1embd, x2embd = embd[runs[0]], embd[runs[1]]   # (N1, 768), (N2, 768)
distance_matrix = euclideanDistance(x1embd, x2embd)     # (N1, N2)

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
# Left panel: per-read maximum for runs[0] (reduce over axis 1);
# right panel: per-read maximum for runs[1] (reduce over axis 0).
for ax, run, reduce_axis in zip(axs, runs[:2], (1, 0)):
    ax.hist(np.max(distance_matrix, axis=reduce_axis), bins=bins)
    # Title from the profile row of this run: Run - Isolate - Pre/Post - Normal/Tumor.
    run_id, isolate_no, treatment_code, tissue_name = profile[profile["Run"] == run].iloc[0]
    phase = "Post" if treatment_code != 0 else "Pre"
    tissue_kind = "Normal" if "normal" in tissue_name else "Tumor"
    info = f"{run_id} - {isolate_no} - {phase} - {tissue_kind}"
    ax.set_title(f"{info}", fontweight='bold', fontsize=fontsize)
    # Fixed x-range for both panels.
    ax.set_xlim(5, 11)
fig.supxlabel("Euclidean Distance", fontweight='bold', y=0.004, fontsize=fontsize)
fig.supylabel("Number of Reads", fontweight='bold', x=0.008, fontsize=fontsize)
fig.tight_layout()
fig.show()

In [None]:
""" Plot Hexbin N * N+1"""

gridsize = 80
vmin, vmax = 4, 64
fontsize = 18

n_runs = len(runs)
fig, axs = plt.subplots(
    n_runs, n_runs + 1, figsize=(35, 30),
    sharex=True, sharey=True
)
# Shared hexbin settings: the diagonal uses the full colour range, every
# filtered panel uses half of the upper limit.
hex_all = dict(cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax)
hex_filtered = dict(cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax/2)

for i, run_i in enumerate(runs):
    xy_i = umap[run_i]
    # Diagonal: UMAP of run i before any distance filter.
    axs[i, i].hexbin(xy_i[:, 0], xy_i[:, 1], **hex_all)
    # Off-diagonal: panel (i, j) shows reads of run i kept against run j,
    # panel (j, i) shows reads of run j kept against run i.
    for j in range(i + 1, n_runs):
        xy_j = umap[runs[j]]
        keep_ij, keep_ji = filter[i][j], filter[j][i]
        axs[i, j].hexbin(xy_i[keep_ij, 0], xy_i[keep_ij, 1], **hex_filtered)
        axs[j, i].hexbin(xy_j[keep_ji, 0], xy_j[keep_ji, 1], **hex_filtered)
    # Last column: reads of run i kept against at least one other run.
    keep_any = np.logical_or.reduce(filter[i][:i] + filter[i][i + 1:])
    axs[i, n_runs].hexbin(xy_i[keep_any, 0], xy_i[keep_any, 1], **hex_filtered)

# Column headers along the top row, row labels on the right of the last column.
for i, run_i in enumerate(runs):
    run_id, isolate_no, treatment_code, tissue_name = profile[profile["Run"] == run_i].iloc[0]
    phase = "Post" if treatment_code != 0 else "Pre"
    tissue_kind = "Normal" if "normal" in tissue_name else "Tumor"
    info = f"{run_id} - {isolate_no} - {phase} - {tissue_kind}"
    axs[0, i].set_xlabel(info, fontsize=fontsize, fontweight='bold')
    axs[0, i].xaxis.set_label_position("top")
    axs[i, -1].set_ylabel(info, fontsize=fontsize, fontweight='bold')
    axs[i, -1].yaxis.set_label_position("right")
    # Same square window for every panel of this row.
    for ax in axs[i]:
        ax.set_aspect("equal")
        ax.set_xlim(-22, 22)
        ax.set_ylim(-22, 22)
# Header of the extra column (independent of the run, so set once).
axs[0, n_runs].set_xlabel("All Selected Reads", fontsize=fontsize, fontweight='bold')
axs[0, n_runs].xaxis.set_label_position("top")

fig.supxlabel("UMAP1", fontsize=fontsize, fontweight='bold', y=0.004)
fig.supylabel("UMAP2", fontsize=fontsize, fontweight='bold', x=0.008)
fig.tight_layout()
fig.savefig("temp.png", dpi=500)
fig.show()

In [None]:
""" Plot Scatter of all samples """

# Marker size shared by all three panels.
s=0.4

# One colour per configured isolate, taken from `isolates` instead of being
# hard-coded to 3/7. The first two entries stay red/blue, so the default
# configuration renders as before.
isolate_palette = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray"]
isolate_color = {
    iso: isolate_palette[k % len(isolate_palette)] for k, iso in enumerate(isolates)
}
# Treatment code (0 = pre, 1 = post) -> (colour, scatter label).
treatment_style = {0: ("red", "Pre"), 1: ("blue", "Post")}

# UMAP coordinates of the reads of each run that were kept against at least
# one other run; computed once and reused by every panel.
selected_by_run = {}
for i, run in enumerate(runs):
    keep_any = np.logical_or.reduce(filter[i][0:i] + filter[i][i+1:])
    selected_by_run[run] = umap[run][keep_any, :]

fig, axs = plt.subplots(1, 3, figsize=(30, 10), sharex=True, sharey=True)
for run in runs:
    xy = selected_by_run[run]
    run_row = profile[profile["Run"] == run].iloc[0]
    run_isolate, run_treatment = run_row["Isolate"], run_row["Treatment"]
    # panel 0: colour by run (matplotlib's default colour cycle)
    axs[0].scatter(xy[:, 0], xy[:, 1], label=f"{run}", s=s)
    # panel 1: colour by isolate
    axs[1].scatter(
        xy[:, 0], xy[:, 1],
        c=isolate_color[run_isolate], label=f"su{run_isolate:03d}", s=s
    )
    # panel 2: colour by treatment
    treatment_color, treatment_label = treatment_style[0 if run_treatment == 0 else 1]
    axs[2].scatter(xy[:, 0], xy[:, 1], c=treatment_color, label=treatment_label, s=s)

for ax in axs:
    ax.set_xlim(-22, 22)
    ax.set_ylim(-22, 22)

axs[0].legend()
# Explicit patch handles give one legend entry per group, not one per run.
axs[1].legend(
    handles=[
        matplotlib.patches.Patch(color=isolate_color[iso], label=f"su{iso:03d}")
        for iso in isolates
    ], loc="upper right", ncol=1, fontsize=14,
)
axs[2].legend(
    handles=[
        matplotlib.patches.Patch(color="r", label="pre "),
        matplotlib.patches.Patch(color="b", label="post")
    ], loc="upper right", ncol=1, fontsize=14,
)
fig.tight_layout()
fig.show()

In [None]:
""" Plot Hexbin of all samples """

gridsize = 80
vmin, vmax = 4, 64
fontsize = 18


def _stackReads(parts: list) -> np.ndarray:
    """Concatenate per-run coordinate arrays; an empty list gives a (0, 2) array."""
    return np.concatenate(parts) if parts else np.empty((0, 2))


# Per run: UMAP coordinates of the reads kept against at least one other run,
# plus the isolate and treatment code of that run from the profile table.
selected_umap, run_isolate, run_treatment = [], [], []
for i, run in enumerate(runs):
    keep_any = np.logical_or.reduce(filter[i][0:i] + filter[i][i+1:])
    selected_umap.append(umap[run][keep_any, :])
    run_row = profile[profile["Run"] == run].iloc[0]
    run_isolate.append(run_row["Isolate"])
    run_treatment.append(run_row["Treatment"])

# hexbin, split by isolate: one panel per entry of `isolates`, in that order
# (previously hard-coded to isolates 3 and 7).
fig, axs = plt.subplots(
    1, len(isolates), figsize=(5 * len(isolates), 5),
    sharex=True, sharey=True, squeeze=False,   # squeeze=False keeps axs 2-D for a single isolate
)
for axs_index, isolate_id in enumerate(isolates):
    ax = axs[0, axs_index]
    points = _stackReads(
        [xy for xy, iso in zip(selected_umap, run_isolate) if iso == isolate_id]
    )
    # Leave the panel empty when no read was selected for this group.
    if len(points) > 0:
        ax.hexbin(
            points[:, 0], points[:, 1],
            cmap="Reds", gridsize=gridsize, vmin=vmin, vmax=vmax,
        )
    ax.set_xlim(-22, 22)
    ax.set_ylim(-22, 22)
    ax.set_title(f"su{isolate_id:03d}", fontweight='bold', fontsize=fontsize)
fig.tight_layout()
fig.show()

# hexbin, split by treatment: panel 0 = pre (code 0), panel 1 = post (code 1)
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
for axs_index in range(2):
    points = _stackReads(
        [xy for xy, code in zip(selected_umap, run_treatment) if code == axs_index]
    )
    if len(points) > 0:
        axs[axs_index].hexbin(
            points[:, 0], points[:, 1],
            cmap="Reds", gridsize=gridsize,
            # the post panel uses half the colour limits of the pre panel
            vmin = vmin if axs_index == 0 else (vmin/2),
            vmax = vmax if axs_index == 0 else (vmax/2)
        )
    axs[axs_index].set_xlim(-22, 22)
    axs[axs_index].set_ylim(-22, 22)
    axs[axs_index].set_title(
        "Pre" if axs_index == 0 else "Post",
        fontweight='bold', fontsize=fontsize
    )
fig.tight_layout()
fig.show()